From 0bb999b61c0eb6a0328c4c703d7485b9e0856ccf Mon Sep 17 00:00:00 2001 From: Wang Bojun <105858416+wwbitejotunn@users.noreply.github.com> Date: Thu, 29 Dec 2022 10:29:00 +0800 Subject: [PATCH] fused_attention_op parameters stop grad support (#49351) * fusedAttenGrad_noGrad * code style fix * add ut * remove unnecessary log --- .../operators/fused/fused_attention_op.cc | 67 ++-- .../operators/fused/fused_attention_op.cu | 17 +- .../unittests/test_fused_attention_op.py | 317 ++++++++++++++++++ 3 files changed, 375 insertions(+), 26 deletions(-) diff --git a/paddle/fluid/operators/fused/fused_attention_op.cc b/paddle/fluid/operators/fused/fused_attention_op.cc index 48d77a7fc9..f25dc393d3 100644 --- a/paddle/fluid/operators/fused/fused_attention_op.cc +++ b/paddle/fluid/operators/fused/fused_attention_op.cc @@ -520,31 +520,50 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel { ctx->SetOutputDim(framework::GradVarName("OutLinearBias"), ctx->GetInputDim("OutLinearBias")); } - ctx->SetOutputDim(framework::GradVarName("OutLinearW"), - ctx->GetInputDim("OutLinearW")); - ctx->SetOutputDim(framework::GradVarName("QKVW"), ctx->GetInputDim("QKVW")); + if (ctx->HasOutput(framework::GradVarName("OutLinearW"))) { + ctx->SetOutputDim(framework::GradVarName("OutLinearW"), + ctx->GetInputDim("OutLinearW")); + } + if (ctx->HasOutput(framework::GradVarName("QKVW"))) { + ctx->SetOutputDim(framework::GradVarName("QKVW"), + ctx->GetInputDim("QKVW")); + } if (ctx->HasOutput(framework::GradVarName("QKVBias"))) { ctx->SetOutputDim(framework::GradVarName("QKVBias"), ctx->GetInputDim("QKVBias")); } if (ctx->Attrs().Get("pre_layer_norm") == true) { - ctx->SetOutputDim(framework::GradVarName("LnOut"), - ctx->GetInputDim("LnOut")); + if (ctx->HasOutput(framework::GradVarName("LnOut"))) { + ctx->SetOutputDim(framework::GradVarName("LnOut"), + ctx->GetInputDim("LnOut")); + } } else { - ctx->SetOutputDim(framework::GradVarName("BiasDropoutResidualOut"), - 
ctx->GetInputDim("BiasDropoutResidualOut")); - } - ctx->SetOutputDim(framework::GradVarName("FMHAOut"), - ctx->GetInputDim("FMHAOut")); - ctx->SetOutputDim(framework::GradVarName("QKTVOut"), - ctx->GetInputDim("QKTVOut")); - ctx->SetOutputDim(framework::GradVarName("TransposeOut2"), - ctx->GetInputDim("TransposeOut2")); - ctx->SetOutputDim(framework::GradVarName("QKOut"), - ctx->GetInputDim("QKOut")); - ctx->SetOutputDim(framework::GradVarName("SoftmaxOut"), - ctx->GetInputDim("SoftmaxOut")); + if (ctx->HasOutput(framework::GradVarName("BiasDropoutResidualOut"))) { + ctx->SetOutputDim(framework::GradVarName("BiasDropoutResidualOut"), + ctx->GetInputDim("BiasDropoutResidualOut")); + } + } + if (ctx->HasOutput(framework::GradVarName("FMHAOut"))) { + ctx->SetOutputDim(framework::GradVarName("FMHAOut"), + ctx->GetInputDim("FMHAOut")); + } + if (ctx->HasOutput(framework::GradVarName("QKTVOut"))) { + ctx->SetOutputDim(framework::GradVarName("QKTVOut"), + ctx->GetInputDim("QKTVOut")); + } + if (ctx->HasOutput(framework::GradVarName("TransposeOut2"))) { + ctx->SetOutputDim(framework::GradVarName("TransposeOut2"), + ctx->GetInputDim("TransposeOut2")); + } + if (ctx->HasOutput(framework::GradVarName("QKOut"))) { + ctx->SetOutputDim(framework::GradVarName("QKOut"), + ctx->GetInputDim("QKOut")); + } + if (ctx->HasOutput(framework::GradVarName("SoftmaxOut"))) { + ctx->SetOutputDim(framework::GradVarName("SoftmaxOut"), + ctx->GetInputDim("SoftmaxOut")); + } if (ctx->HasOutput(framework::GradVarName("AttnDropoutOut"))) { ctx->SetOutputDim(framework::GradVarName("AttnDropoutOut"), ctx->GetInputDim("AttnDropoutOut")); @@ -554,14 +573,18 @@ class FusedAttentionGradOp : public framework::OperatorWithKernel { ctx->SetOutputDim(framework::GradVarName("SrcMaskOut"), ctx->GetInputDim("SrcMaskOut")); } - ctx->SetOutputDim(framework::GradVarName("QKVOut"), - ctx->GetInputDim("QKVOut")); + if (ctx->HasOutput(framework::GradVarName("QKVOut"))) { + 
ctx->SetOutputDim(framework::GradVarName("QKVOut"), + ctx->GetInputDim("QKVOut")); + } if (ctx->HasOutput(framework::GradVarName("QKVBiasOut"))) { ctx->SetOutputDim(framework::GradVarName("QKVBiasOut"), ctx->GetInputDim("QKVBiasOut")); } - ctx->SetOutputDim(framework::GradVarName("OutLinearOut"), - ctx->GetInputDim("OutLinearOut")); + if (ctx->HasOutput(framework::GradVarName("OutLinearOut"))) { + ctx->SetOutputDim(framework::GradVarName("OutLinearOut"), + ctx->GetInputDim("OutLinearOut")); + } } protected: diff --git a/paddle/fluid/operators/fused/fused_attention_op.cu b/paddle/fluid/operators/fused/fused_attention_op.cu index 559a2afb85..d963e73965 100644 --- a/paddle/fluid/operators/fused/fused_attention_op.cu +++ b/paddle/fluid/operators/fused/fused_attention_op.cu @@ -514,15 +514,24 @@ class FusedAttentionGradKernel : public framework::OpKernel { auto *d_ln_2_bias = ctx.Output(framework::GradVarName("Ln2Bias")); - auto *d_qkv_weight_data = dev_ctx.template Alloc( - d_qkv_weight, d_qkv_weight->numel() * sizeof(T)); + auto *d_qkv_weight_data = + (d_qkv_weight == nullptr) + ? nullptr + : dev_ctx.template Alloc(d_qkv_weight, + d_qkv_weight->numel() * sizeof(T)); + auto *d_qkv_bias_data = (d_qkv_bias == nullptr) ? nullptr : dev_ctx.template Alloc(d_qkv_bias, d_qkv_bias->numel() * sizeof(T)); - auto *d_out_linear_weight_data = dev_ctx.template Alloc( - d_out_linear_weight, d_out_linear_weight->numel() * sizeof(T)); + auto *d_out_linear_weight_data = + (d_out_linear_weight == nullptr) + ? nullptr + : dev_ctx.template Alloc( + d_out_linear_weight, + d_out_linear_weight->numel() * sizeof(T)); + auto *d_out_linear_bias_data = (d_out_linear_bias == nullptr) ? 
nullptr diff --git a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py index 47296e48a2..d09d0b4fbe 100644 --- a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py +++ b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py @@ -390,5 +390,322 @@ class TestFusedAttentionOpCacheKV(TestFusedAttentionOp): ) +class TestFusedAttentionOpParamStopGradient(OpTest): + def setUp(self): + self.config() + self.generate_input_data() + + self.rtol = 1e-5 + # FIXME(limin29): Because there is a problem with the test precision + # on A100, atol is temporarily set to 1e-2, and it will be + # changed back after the precision problem is solved. + self.atol = 1e-2 + # make sure local development precision + if "V100" in paddle.device.cuda.get_device_name(): + self.atol = 1e-4 + if self.x_type is np.float16: + self.atol = 1e-1 + + paddle.set_default_dtype(self.x_type) + self.__class__.op_type = "fused_attention" + # use autograd to check grad in this unittest. 
+ self.__class__.no_need_check_grad = True + self.q_proj = Linear( + self.embed_dim, + self.embed_dim, + self.weight_attr, + bias_attr=self.bias_attr, + ) + self.k_proj = Linear( + self.kdim, + self.embed_dim, + self.weight_attr, + bias_attr=self.bias_attr, + ) + self.v_proj = Linear( + self.vdim, + self.embed_dim, + self.weight_attr, + bias_attr=self.bias_attr, + ) + self.out_proj = Linear( + self.embed_dim, + self.embed_dim, + self.weight_attr, + bias_attr=self.bias_attr, + ) + paddle.set_default_dtype(np.float32) + self.norm1 = LayerNorm(self.embed_dim) + self.norm2 = LayerNorm(self.embed_dim) + paddle.set_default_dtype(self.x_type) + self.dropout = Dropout(self.dropout_prob, mode="upscale_in_train") + + def config(self): + self.x_type = np.float32 + self.attn_mask_type = np.float64 + self.pre_layer_norm = False + self.has_attn_mask = True + self.has_cache_kv = False + self.training = True + + self.batch_size = 8 + self.query_length = 128 + self.cache_length = 128 + self.head_dim = 64 + self.num_heads = 16 + self.embed_dim = self.head_dim * self.num_heads + + self.dropout_prob = 0.0 + self.attn_dropout_prob = 0.0 + self.weight_attr = None + self.bias_attr = None + self.kdim, self.vdim = self.embed_dim, self.embed_dim + self.key_length, self.value_length = ( + self.query_length, + self.query_length, + ) + + def generate_input_data(self): + self.query = np.random.rand( + self.batch_size, self.query_length, self.embed_dim + ).astype(self.x_type) + out_seq_len = self.key_length + if self.has_cache_kv: + assert self.training is False, ValueError( + 'cache_kv can only used in inference' + ) + self.cache_kv = np.random.rand( + 2, + self.batch_size, + self.num_heads, + self.cache_length, + self.head_dim, + ).astype(self.x_type) + out_seq_len += self.cache_length + else: + self.cache_kv = None + + if self.has_attn_mask: + # [B, n_head, seq_len, out_seq_len] + self.attn_mask = np.ones( + ( + self.batch_size, + self.num_heads, + self.query_length, + out_seq_len, + ), + 
dtype=self.attn_mask_type, + ) + if self.attn_mask_type == np.int64: + self.attn_mask = np.tril(self.attn_mask) + elif self.attn_mask_type == np.float64: + self.attn_mask = (np.tril(self.attn_mask) - 1.0) * 1e9 + else: + raise ValueError( + "'attn_mask_type' should be 'int64' or 'float64'." + ) + else: + self.attn_mask = None + self.key, self.value = self.query, self.query + + self.dout = np.random.random( + (self.batch_size, self.query_length, self.embed_dim) + ).astype(self.x_type) + + def GetBaselineOut(self): + paddle.disable_static(place=paddle.CUDAPlace(0)) + tensor_query = paddle.to_tensor(self.query, stop_gradient=False) + + cache_kv = None + if self.has_cache_kv: + cache_kv = paddle.to_tensor(self.cache_kv, stop_gradient=False) + + if self.has_attn_mask: + attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False) + else: + attn_mask = None + residual = tensor_query + + ln1_out = tensor_query + if self.pre_layer_norm: + ln1_out = self.norm1(tensor_query) + + q = self.q_proj(ln1_out) + q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim]) + q_out = tensor.transpose(x=q, perm=[0, 2, 1, 3]) + k = self.k_proj(ln1_out) + v = self.v_proj(ln1_out) + k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim]) + k_out = tensor.transpose(x=k, perm=[0, 2, 1, 3]) + v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim]) + v_out = tensor.transpose(x=v, perm=[0, 2, 1, 3]) + + if self.has_cache_kv: + # [1, B, n_head, cache_seq_len, head_dim] + cache_k, cache_v = paddle.split(cache_kv, 2) + cache_k = paddle.squeeze(cache_k, axis=0) + cache_v = paddle.squeeze(cache_v, axis=0) + # [B, n_head, cache_seq_len + seq_len, head_dim] + # out_seq_len = cache_seq_len + seq_len + k_out = paddle.concat([cache_k, k_out], axis=-2) + v_out = paddle.concat([cache_v, v_out], axis=-2) + + # [B, n_head, seq_len, head_dim] * [B, n_head, out_seq_len, head_dim] + # --> [B, n_head, seq_len, out_seq_len] + qk_out = paddle.matmul(x=q_out, y=k_out, 
transpose_y=True) + qk_out = paddle.scale(qk_out, scale=self.head_dim**-0.5) + + if attn_mask is not None: + attn_mask = _convert_attention_mask(attn_mask, qk_out.dtype) + attn_mask_out = qk_out + attn_mask + softmax_out = F.softmax(attn_mask_out) + else: + softmax_out = F.softmax(qk_out) + + if self.dropout_prob: + dropout_out = F.dropout( + softmax_out, + self.dropout_prob, + training=self.training, + mode="upscale_in_train", + ) + # [B, n_head, seq_len, out_seq_len] * [B, n_head, out_seq_len, head_dim] + # --> [B, n_head, seq_len, head_dim] + qktv_out = tensor.matmul(dropout_out, v_out) + else: + qktv_out = tensor.matmul(softmax_out, v_out) + + fmha_out = tensor.transpose(qktv_out, perm=[0, 2, 1, 3]) + out_linear_in = tensor.reshape( + x=fmha_out, shape=[0, 0, fmha_out.shape[2] * fmha_out.shape[3]] + ) + out = self.out_proj(out_linear_in) + + residual_out = residual + self.dropout(out) + if not self.pre_layer_norm: + final_out = self.norm1(residual_out) + else: + final_out = residual_out + + if self.has_cache_kv: + return final_out + + paddle.autograd.backward( + [final_out], [paddle.to_tensor(self.dout)], retain_graph=True + ) + return final_out, tensor_query.grad + + def GetFusedAttentionOut(self): + paddle.disable_static(place=paddle.CUDAPlace(0)) + q_proj_weight = paddle.to_tensor( + self.q_proj.weight, stop_gradient=False + ) + k_proj_weight = paddle.to_tensor( + self.k_proj.weight, stop_gradient=False + ) + v_proj_weight = paddle.to_tensor( + self.v_proj.weight, stop_gradient=False + ) + out_linear_weight = paddle.to_tensor( + self.out_proj.weight, stop_gradient=False + ) + + if self.bias_attr is False: + qkv_bias_tensor = None + out_linear_bias = None + else: + q_proj_bias = paddle.to_tensor( + self.q_proj.bias, stop_gradient=False + ) + k_proj_bias = paddle.to_tensor( + self.k_proj.bias, stop_gradient=False + ) + v_proj_bias = paddle.to_tensor( + self.v_proj.bias, stop_gradient=False + ) + qkv_bias = np.concatenate( + (q_proj_bias.numpy(), 
k_proj_bias.numpy(), v_proj_bias.numpy()) + ) + qkv_bias = qkv_bias.reshape((3, self.num_heads, self.head_dim)) + qkv_bias_tensor = paddle.to_tensor(qkv_bias, stop_gradient=False) + out_linear_bias = paddle.to_tensor( + self.out_proj.bias, stop_gradient=False + ) + + ln1_scale = paddle.to_tensor(self.norm1.weight, stop_gradient=False) + ln1_bias = paddle.to_tensor(self.norm1.bias, stop_gradient=False) + ln2_scale = paddle.to_tensor(self.norm2.weight, stop_gradient=False) + ln2_bias = paddle.to_tensor(self.norm2.bias, stop_gradient=False) + + q_proj_weight = q_proj_weight.numpy().transpose((1, 0)) + k_proj_weight = k_proj_weight.numpy().transpose((1, 0)) + v_proj_weight = v_proj_weight.numpy().transpose((1, 0)) + qkv_weight = np.concatenate( + (q_proj_weight, k_proj_weight, v_proj_weight) + ) + qkv_weight = qkv_weight.reshape( + (3, self.num_heads, self.head_dim, self.embed_dim) + ) + + x = paddle.to_tensor(self.query, stop_gradient=False) + cache_kv = None + if self.has_cache_kv: + cache_kv = paddle.to_tensor(self.cache_kv, stop_gradient=False) + if self.has_attn_mask: + attn_mask = paddle.to_tensor(self.attn_mask, stop_gradient=False) + else: + attn_mask = None + qkv_weight_tensor = paddle.to_tensor(qkv_weight, stop_gradient=False) + epsilon = 1e-05 + ln2_epsilon = 1e-05 + + if attn_mask is not None: + attn_mask = _convert_attention_mask(attn_mask, x.dtype) + qkv_weight_tensor.stop_gradient = True + out_linear_weight.stop_gradient = True + ln1_scale.stop_gradient = True + ln1_bias.stop_gradient = True + ln2_scale.stop_gradient = True + ln2_bias.stop_gradient = True + qkv_bias_tensor.stop_gradient = True + out_linear_bias.stop_gradient = True + final_out = incubate_f.fused_multi_head_attention( + x, + qkv_weight_tensor, + out_linear_weight, + self.pre_layer_norm, + ln1_scale, + ln1_bias, + ln2_scale, + ln2_bias, + epsilon, + qkv_bias_tensor, + out_linear_bias, + cache_kv, + attn_mask, + self.dropout_prob, + self.attn_dropout_prob, + ln2_epsilon, + ) + + if 
self.has_cache_kv: + return final_out[0], final_out[1] + + paddle.autograd.backward( + [final_out], [paddle.to_tensor(self.dout)], retain_graph=True + ) + return final_out, x.grad + + def test_fused_attention_op(self): + final_out_ref, x_grad_ref = self.GetBaselineOut() + final_out, x_grad = self.GetFusedAttentionOut() + np.testing.assert_allclose( + final_out_ref, final_out.numpy(), rtol=self.rtol, atol=self.atol + ) + np.testing.assert_allclose( + x_grad_ref, x_grad.numpy(), rtol=self.rtol, atol=self.atol + ) + + if __name__ == "__main__": unittest.main() -- GitLab