Unverified commit 071708fa authored by T taixiurong, committed by GitHub

xpu-paddlepaddle-41 [Task] ffn and attention test=kunlun (#46658)

Parent b4460eee
......@@ -38,6 +38,8 @@ if(WITH_XPU)
op_library(resnet_basic_block_op)
op_library(resnet_unit_op)
op_library(fused_gemm_epilogue_op)
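# XPU implementations of the fused attention and fused feedforward operators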
op_library(fused_attention_op)
op_library(fused_feedforward_op)
endif()
if(WITH_GPU OR WITH_ROCM)
......
This diff is collapsed.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/platform/device/device_wrapper.h"
namespace paddle {
namespace operators {
using Tensor = phi::DenseTensor;
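// Collects the dropout configuration that the fused attention/feedforward XPU
// kernels read from an operator's attributes: the dropout rate, whether the
// "upscale_in_train" implementation is used, test mode, and the seed (either a
// fixed attribute value or a seed tensor input).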
struct XPUDropoutParam {
float dropout_prob;
bool is_upscale_in_train;
bool is_test;
bool fix_seed;
const Tensor *tensor_seed;
int seed_val;
XPUDropoutParam() {
fix_seed = false;
is_test = false;
is_upscale_in_train = false;
dropout_prob = 0.5;
tensor_seed = nullptr;
seed_val = 0;
}
XPUDropoutParam(const framework::ExecutionContext &context,
const int dropout_index) {
std::string pre_fix = "dropout";
std::string str_index = std::to_string(dropout_index);
if (dropout_index > 0) {
pre_fix = pre_fix + str_index + "_";
} else {
pre_fix = pre_fix + "_";
}
dropout_prob = context.Attr<float>(pre_fix + "rate");
auto &dropout_implementation =
context.Attr<std::string>(pre_fix + "implementation");
is_upscale_in_train = (dropout_implementation == "upscale_in_train");
is_test = context.Attr<bool>("is_test");
fix_seed = context.Attr<bool>(pre_fix + "fix_seed");
std::string str_seed = "Dropout";
if (dropout_index > 0) {
str_seed = str_seed + str_index + "Seed";
} else {
str_seed = str_seed + "Seed";
}
tensor_seed =
context.HasInput(str_seed) ? context.Input<Tensor>(str_seed) : nullptr;
if (tensor_seed) {
seed_val = *(tensor_seed->data<int>());
} else {
seed_val = fix_seed ? context.Attr<int>(pre_fix + "seed") : 0;
}
}
void initXPUDropoutParam(float dropout_prob_,
bool is_upscale_in_train_,
bool is_test_,
bool fix_seed_,
const Tensor *tensor_seed,
int seed_val_) {
dropout_prob = dropout_prob_;
is_upscale_in_train = is_upscale_in_train_;
is_test = is_test_;
fix_seed = fix_seed_;
if (tensor_seed) {
seed_val = *(tensor_seed->data<int>());
} else {
seed_val = fix_seed ? seed_val_ : 0;
}
}
void initXPUDropoutParam(const framework::ExecutionContext &context,
int dropout_index) {
std::string pre_fix = "dropout";
std::string str_index = std::to_string(dropout_index);
if (dropout_index > 0) {
pre_fix = pre_fix + str_index + "_";
} else {
pre_fix = pre_fix + "_";
}
dropout_prob = context.Attr<float>(pre_fix + "rate");
auto &dropout_implementation =
context.Attr<std::string>(pre_fix + "implementation");
is_upscale_in_train = (dropout_implementation == "upscale_in_train");
is_test = context.Attr<bool>("is_test");
fix_seed = context.Attr<bool>(pre_fix + "fix_seed");
std::string str_seed = "Dropout";
if (dropout_index > 0) {
str_seed = str_seed + str_index + "Seed";
} else {
str_seed = str_seed + "Seed";
}
tensor_seed =
context.HasInput(str_seed) ? context.Input<Tensor>(str_seed) : nullptr;
if (tensor_seed) {
seed_val = *(tensor_seed->data<int>());
} else {
seed_val = fix_seed ? context.Attr<int>(pre_fix + "seed") : 0;
}
}
};
/******************
 * check whether an address is in L3 memory
 * (L3 addresses fall within the low 32-bit range)
 *******************/
static bool is_in_l3(const void *addr) {
int64_t addr_int = (int64_t)addr;
int addr_int_high = addr_int >> 32;
return (addr_int_high == 0);
}
/*************************
* dropout
*************************/
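// Forward dropout helper built on xdnn primitives: a plain copy when the rate
// is 0, zeroed output and mask when the rate is 1 in training, xpu::dropout
// otherwise, and a simple scale (1 or 1 - rate) in test mode.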
template <typename T>
void Dropout(xpu::Context *xpu_ctx,
const T *x,
T *mask,
T *y,
const XPUDropoutParam &param,
int len) {
using XPUType = typename XPUTypeTrait<T>::Type;
int r = XPU_SUCCESS;
if (param.dropout_prob == 0.0f) {
r = xpu::copy(xpu_ctx,
reinterpret_cast<const XPUType *>(x),
reinterpret_cast<XPUType *>(y),
len);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
return;
}
if (!param.is_test) {
if (param.dropout_prob == 1.0f) {
r = xpu::constant(
xpu_ctx, reinterpret_cast<XPUType *>(y), len, XPUType(0));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
r = xpu::constant(
xpu_ctx, reinterpret_cast<XPUType *>(mask), len, XPUType(0));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
} else {
r = xpu::dropout(xpu_ctx,
reinterpret_cast<const XPUType *>(x),
reinterpret_cast<XPUType *>(y),
reinterpret_cast<XPUType *>(mask),
param.seed_val,
len,
param.is_upscale_in_train,
param.dropout_prob);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "dropout");
}
} else {
float scale = (param.is_upscale_in_train)
? (1.0)
: (static_cast<float>(1.0f - param.dropout_prob));
r = xpu::scale(xpu_ctx,
reinterpret_cast<const XPUType *>(x),
reinterpret_cast<XPUType *>(y),
len,
false,
scale,
0.0f);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale");
}
}
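// Backward dropout helper: pass-through when the rate is 0, an elementwise
// multiply with the saved mask for the downscale_in_infer implementation, and
// xpu::dropout_grad (mask plus the upscale_in_train rescaling) otherwise.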
template <typename T>
void DropoutGrad(xpu::Context *xpu_ctx,
const T *dy,
const T *mask,
T *dx,
const XPUDropoutParam &param,
int len) {
using XPUType = typename XPUTypeTrait<T>::Type;
if (param.dropout_prob == 0.0f) {
int r = xpu::copy(xpu_ctx,
reinterpret_cast<const XPUType *>(dy),
reinterpret_cast<XPUType *>(dx),
len);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
return;
}
if (!param.is_upscale_in_train) {
int r = xpu::mul(xpu_ctx,
reinterpret_cast<const XPUType *>(dy),
reinterpret_cast<const XPUType *>(mask),
reinterpret_cast<XPUType *>(dx),
len);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "mul");
} else {
int r = xpu::dropout_grad(xpu_ctx,
reinterpret_cast<const XPUType *>(mask),
reinterpret_cast<const XPUType *>(dy),
reinterpret_cast<XPUType *>(dx),
param.dropout_prob,
len);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "dropout_grad");
}
}
} // namespace operators
} // namespace paddle
#endif
......@@ -704,6 +704,18 @@ XPUOpMap& get_kl2_ops() {
{"fused_gemm_epilogue_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"fused_attention",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"fused_attention_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"fused_feedforward",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"fused_feedforward_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
};
return s_xpu2_kernels;
......
......@@ -382,7 +382,8 @@ static void MatMulXPUFunction(xpu::Context* xpu_ctx,
const T* y,
T* out,
const XpuFcInfo& fcinfo,
float alpha) {
float alpha,
bool is_grad = false) {
using XPUType = typename XPUTypeTrait<T>::Type;
int fccal_type = FCCalcType<XPUType>();
......@@ -398,6 +399,12 @@ static void MatMulXPUFunction(xpu::Context* xpu_ctx,
};
auto fc_api = fc_api_list[fccal_type];
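// If the XPU_PADDLE_FC_GRAD_LOCAL environment variable is set, gradient
// matmuls are forced onto fc_api_list[2] instead of the implementation
// selected by FCCalcType.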
if (std::getenv("XPU_PADDLE_FC_GRAD_LOCAL") != nullptr) {
if (is_grad) {
fc_api = fc_api_list[2];
}
}
auto fc_batch_api = fc_batch_api_list[fccal_type];
int m = fcinfo.m;
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import sys
sys.path.append("..")
import paddle
import paddle.nn.functional as F
import paddle.incubate.nn.functional as incubate_f
from paddle.nn.layer.norm import LayerNorm
from paddle.nn.layer.common import Linear, Dropout
from paddle.nn.layer.transformer import _convert_attention_mask
from paddle import tensor
from paddle.fluid import layers
import unittest
from op_test_xpu import XPUOpTest
from paddle.fluid.framework import default_main_program
from xpu.get_test_cover_info import (
create_test_class,
get_xpu_op_support_types,
XPUOpTestWrapper,
)
default_main_program().random_seed = 42
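# Compares the fused_attention XPU kernel against a reference attention block
# assembled from separate Linear/LayerNorm/matmul/softmax ops, checking both
# the forward output and the gradient w.r.t. the input query.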
class XPUTestFusedAttentionOp(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'fused_attention'
self.use_dynamic_create_class = False
class TestFusedAttentionOp(XPUOpTest):
def setUp(self):
self.config()
self.generate_input_data()
self.rtol = 1e-5
self.atol = 1e-3
if self.x_type == np.float16 or str(self.x_type) == "float16":
self.atol = 1e-1
paddle.set_default_dtype(self.x_type)
self.__class__.op_type = "fused_attention"
# use autograd to check grad in this unittest.
self.__class__.no_need_check_grad = True
self.q_proj = Linear(
self.embed_dim,
self.embed_dim,
self.weight_attr,
bias_attr=self.bias_attr,
)
self.k_proj = Linear(
self.kdim,
self.embed_dim,
self.weight_attr,
bias_attr=self.bias_attr,
)
self.v_proj = Linear(
self.vdim,
self.embed_dim,
self.weight_attr,
bias_attr=self.bias_attr,
)
self.out_proj = Linear(
self.embed_dim,
self.embed_dim,
self.weight_attr,
bias_attr=self.bias_attr,
)
paddle.set_default_dtype(np.float32)
self.norm1 = LayerNorm(self.embed_dim)
self.norm2 = LayerNorm(self.embed_dim)
paddle.set_default_dtype(self.x_type)
self.dropout = Dropout(self.dropout_prob, mode="upscale_in_train")
def config(self):
self.x_type = self.in_type
self.attn_mask_type = np.float32
self.pre_layer_norm = True
self.has_attn_mask = False
self.training = True
self.batch_size = 8
self.query_length = 128
self.cache_length = 128
self.head_dim = 64
self.num_heads = 16
self.embed_dim = self.head_dim * self.num_heads
self.dropout_prob = 0.0
self.attn_dropout_prob = 0.0
self.weight_attr = None
self.bias_attr = None
self.kdim, self.vdim = self.embed_dim, self.embed_dim
self.key_length, self.value_length = (
self.query_length,
self.query_length,
)
def generate_input_data(self):
self.query = np.random.rand(
self.batch_size, self.query_length, self.embed_dim
).astype(self.x_type)
out_seq_len = self.key_length
if self.has_attn_mask:
# [B, n_head, seq_len, out_seq_len]
self.attn_mask = np.ones(
(
self.batch_size,
self.num_heads,
self.query_length,
out_seq_len,
),
dtype=self.attn_mask_type,
)
else:
self.attn_mask = None
self.key, self.value = self.query, self.query
self.dout = np.random.random(
(self.batch_size, self.query_length, self.embed_dim)
).astype(self.x_type)
def GetBaselineOut(self):
paddle.disable_static()
tensor_query = paddle.to_tensor(self.query, stop_gradient=False)
if self.has_attn_mask:
attn_mask = paddle.to_tensor(
self.attn_mask, stop_gradient=False
)
else:
attn_mask = None
residual = tensor_query
ln1_out = tensor_query
if self.pre_layer_norm:
ln1_out = self.norm1(tensor_query)
q = self.q_proj(ln1_out)
q = tensor.reshape(x=q, shape=[0, 0, self.num_heads, self.head_dim])
q_out = tensor.transpose(x=q, perm=[0, 2, 1, 3])
k = self.k_proj(ln1_out)
v = self.v_proj(ln1_out)
k = tensor.reshape(x=k, shape=[0, 0, self.num_heads, self.head_dim])
k_out = tensor.transpose(x=k, perm=[0, 2, 1, 3])
v = tensor.reshape(x=v, shape=[0, 0, self.num_heads, self.head_dim])
v_out = tensor.transpose(x=v, perm=[0, 2, 1, 3])
# [B, n_head, seq_len, head_dim] * [B, n_head, out_seq_len, head_dim]
# --> [B, n_head, seq_len, out_seq_len]
qk_out = layers.matmul(
x=q_out * self.head_dim**-0.5, y=k_out, transpose_y=True
)
if attn_mask is not None:
attn_mask = _convert_attention_mask(attn_mask, qk_out.dtype)
attn_mask_out = qk_out + attn_mask
softmax_out = F.softmax(attn_mask_out)
else:
softmax_out = F.softmax(qk_out)
if self.dropout_prob:
dropout_out = F.dropout(
softmax_out,
self.dropout_prob,
training=self.training,
mode="upscale_in_train",
)
# [B, n_head, seq_len, out_seq_len] * [B, n_head, out_seq_len, head_dim]
# --> [B, n_head, seq_len, head_dim]
qktv_out = tensor.matmul(dropout_out, v_out)
else:
qktv_out = tensor.matmul(softmax_out, v_out)
fmha_out = tensor.transpose(qktv_out, perm=[0, 2, 1, 3])
out_linear_in = tensor.reshape(
x=fmha_out, shape=[0, 0, fmha_out.shape[2] * fmha_out.shape[3]]
)
out = self.out_proj(out_linear_in)
residual_out = residual + self.dropout(out)
if not self.pre_layer_norm:
final_out = self.norm1(residual_out)
else:
final_out = residual_out
paddle.autograd.backward(
[final_out], [paddle.to_tensor(self.dout)], retain_graph=True
)
return final_out, tensor_query.grad
def GetFusedAttentionOut(self):
paddle.disable_static()
q_proj_weight = paddle.to_tensor(
self.q_proj.weight, stop_gradient=False
)
k_proj_weight = paddle.to_tensor(
self.k_proj.weight, stop_gradient=False
)
v_proj_weight = paddle.to_tensor(
self.v_proj.weight, stop_gradient=False
)
out_linear_weight = paddle.to_tensor(
self.out_proj.weight, stop_gradient=False
)
if self.bias_attr is False:
qkv_bias_tensor = None
out_linear_bias = None
else:
q_proj_bias = paddle.to_tensor(
self.q_proj.bias, stop_gradient=False
)
k_proj_bias = paddle.to_tensor(
self.k_proj.bias, stop_gradient=False
)
v_proj_bias = paddle.to_tensor(
self.v_proj.bias, stop_gradient=False
)
qkv_bias = np.concatenate(
(
q_proj_bias.numpy(),
k_proj_bias.numpy(),
v_proj_bias.numpy(),
)
)
qkv_bias = qkv_bias.reshape((3, self.num_heads, self.head_dim))
qkv_bias_tensor = paddle.to_tensor(
qkv_bias, stop_gradient=False
)
out_linear_bias = paddle.to_tensor(
self.out_proj.bias, stop_gradient=False
)
ln1_scale = paddle.to_tensor(self.norm1.weight, stop_gradient=False)
ln1_bias = paddle.to_tensor(self.norm1.bias, stop_gradient=False)
ln2_scale = paddle.to_tensor(self.norm2.weight, stop_gradient=False)
ln2_bias = paddle.to_tensor(self.norm2.bias, stop_gradient=False)
q_proj_weight = q_proj_weight.numpy().transpose((1, 0))
k_proj_weight = k_proj_weight.numpy().transpose((1, 0))
v_proj_weight = v_proj_weight.numpy().transpose((1, 0))
qkv_weight = np.concatenate(
(q_proj_weight, k_proj_weight, v_proj_weight)
)
qkv_weight = qkv_weight.reshape(
(3, self.num_heads, self.head_dim, self.embed_dim)
)
x = paddle.to_tensor(self.query, stop_gradient=False)
cache_kv = None
if self.has_attn_mask:
attn_mask = paddle.to_tensor(
self.attn_mask, stop_gradient=False
)
else:
attn_mask = None
qkv_weight_tensor = paddle.to_tensor(
qkv_weight, stop_gradient=False
)
epsilon = 1e-05
ln2_epsilon = 1e-05
if attn_mask is not None:
attn_mask = _convert_attention_mask(attn_mask, x.dtype)
final_out = incubate_f.fused_multi_head_attention(
x,
qkv_weight_tensor,
out_linear_weight,
self.pre_layer_norm,
ln1_scale,
ln1_bias,
ln2_scale,
ln2_bias,
epsilon,
qkv_bias_tensor,
out_linear_bias,
cache_kv,
attn_mask,
self.dropout_prob,
self.attn_dropout_prob,
ln2_epsilon,
)
paddle.autograd.backward(
[final_out], [paddle.to_tensor(self.dout)], retain_graph=True
)
return final_out, x.grad
def test_fused_attention_op(self):
final_out_ref, x_grad_ref = self.GetBaselineOut()
final_out, x_grad = self.GetFusedAttentionOut()
np.testing.assert_allclose(
final_out_ref, final_out.numpy(), rtol=self.rtol, atol=self.atol
)
np.testing.assert_allclose(
x_grad_ref, x_grad.numpy(), rtol=self.rtol, atol=self.atol
)
class TestFusedAttentionOpPreLn(TestFusedAttentionOp):
def config(self):
super().config()
self.pre_layer_norm = True
class TestFusedAttentionOpNoneAttnMask(TestFusedAttentionOp):
def config(self):
super().config()
self.pre_layer_norm = True
self.has_attn_mask = False
support_types = get_xpu_op_support_types('fused_attention')
for stype in support_types:
create_test_class(globals(), XPUTestFusedAttentionOp, stype)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import sys
sys.path.append("..")
import paddle
from paddle.nn.layer import transformer
import paddle.nn.functional as F
import paddle.incubate.nn.functional as incubate_f
from paddle.nn.layer.norm import LayerNorm
from paddle.nn.layer.common import Linear, Dropout
import unittest
from op_test_xpu import XPUOpTest
from paddle.fluid.framework import default_main_program
from xpu.get_test_cover_info import (
create_test_class,
XPUOpTestWrapper,
)
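# Compares the fused_feedforward XPU kernel against a reference FFN built from
# Linear/Dropout/LayerNorm layers, checking both the forward output and the
# input gradient.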
class XPUTestFusedFFNOp(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'fused_feedforward'
self.use_dynamic_create_class = False
class TestFusedFFNOp(XPUOpTest):
def getDtype(self):
self.dtype = self.in_type
self.layer_norm_dtype = "float32"
def getShape(self):
self.batch_size = np.random.randint(1, 32)
self.query_length = np.random.randint(32, 128)
self.d_model = np.random.randint(32, 512)
self.dim_feedforward = np.random.randint(32, 512)
def getDiff(self):
self.rtol = 1e-2
self.atol = 1e-3
if self.dtype == np.float16 or self.dtype == "float16":
self.atol = 1e-1
def getActivation(self):
self.act_method = "gelu"
def getNormalizeBefore(self):
self.pre_layer_norm = False
def setUp(self):
paddle.disable_static()
self.__class__.op_type = "fused_feedforward"
# check grad in test_out_and_grad()
self.__class__.no_need_check_grad = True
self.getDtype()
self.getShape()
self.getDiff()
self.getActivation()
self.getNormalizeBefore()
paddle.set_default_dtype(self.dtype)
self.weight_attr = None
self.bias_attr = None
self.weight_attrs = transformer._convert_param_attr_to_list(
self.weight_attr, 2
)
self.bias_attrs = transformer._convert_param_attr_to_list(
self.bias_attr, 2
)
self.linear1 = Linear(
self.d_model,
self.dim_feedforward,
self.weight_attrs[1],
bias_attr=self.bias_attrs[1],
)
self.linear2 = Linear(
self.dim_feedforward,
self.d_model,
self.weight_attrs[1],
bias_attr=self.bias_attrs[1],
)
paddle.set_default_dtype(self.layer_norm_dtype)
self.norm1 = LayerNorm(self.d_model)
self.norm2 = LayerNorm(self.d_model)
paddle.set_default_dtype(self.dtype)
self.dropout1 = Dropout(0.0, mode="upscale_in_train")
self.dropout2 = Dropout(0.0, mode="upscale_in_train")
self.activation = getattr(F, self.act_method)
self.src = np.random.random(
(self.batch_size, self.query_length, self.d_model)
).astype(self.dtype)
self.dout = np.random.random(
(self.batch_size, self.query_length, self.d_model)
).astype(self.dtype)
def Base(self):
paddle.disable_static()
tensor_src = paddle.to_tensor(self.src, stop_gradient=False)
residual = tensor_src
if self.pre_layer_norm:
ln1_out = self.norm1(tensor_src)
linear2_out = self.linear2(
self.dropout1(self.activation(self.linear1(ln1_out)))
)
dropout2_out = residual + self.dropout2(linear2_out)
paddle.autograd.backward(
[dropout2_out], [paddle.to_tensor(self.dout)], True
)
return dropout2_out, tensor_src.grad
else:
linear2_out = self.linear2(
self.dropout1(self.activation(self.linear1(tensor_src)))
)
dropout2_out = residual + self.dropout2(linear2_out)
dropout2_out = self.norm2(dropout2_out)
paddle.autograd.backward(
[dropout2_out], [paddle.to_tensor(self.dout)], True
)
return dropout2_out, tensor_src.grad
def FusedFFN(self):
paddle.disable_static()
linear1_weight = paddle.to_tensor(
self.linear1.weight, stop_gradient=False
)
linear1_bias = paddle.to_tensor(
self.linear1.bias, stop_gradient=False
)
linear2_weight = paddle.to_tensor(
self.linear2.weight, stop_gradient=False
)
linear2_bias = paddle.to_tensor(
self.linear2.bias, stop_gradient=False
)
ln1_scale = paddle.to_tensor(self.norm1.weight, stop_gradient=False)
ln1_bias = paddle.to_tensor(self.norm1.bias, stop_gradient=False)
ln2_scale = paddle.to_tensor(self.norm2.weight, stop_gradient=False)
ln2_bias = paddle.to_tensor(self.norm2.bias, stop_gradient=False)
x = paddle.to_tensor(self.src, stop_gradient=False)
out = incubate_f.fused_feedforward(
x,
linear1_weight,
linear2_weight,
linear1_bias,
linear2_bias,
ln1_scale,
ln1_bias,
ln2_scale,
ln2_bias,
0.0,
0.0,
activation=self.act_method,
pre_layer_norm=self.pre_layer_norm,
)
paddle.autograd.backward([out], [paddle.to_tensor(self.dout)])
return out, x.grad
def test_out_and_grad(self):
default_main_program().random_seed = 42
base_out, base_grad = self.Base()
fused_out, fused_grad = self.FusedFFN()
np.testing.assert_allclose(
base_out.numpy(),
fused_out.numpy(),
rtol=self.rtol,
atol=self.atol,
)
np.testing.assert_allclose(
base_grad.numpy(),
fused_grad.numpy(),
rtol=self.rtol,
atol=self.atol,
)
class TestFusedFFNOpActivation(TestFusedFFNOp):
def getActivation(self):
self.act_method = "relu"
class TestFusedFFNOpNormalizeBefore(TestFusedFFNOp):
def getNormalizeBefore(self):
self.pre_layer_norm = True
def getShape(self):
self.batch_size = 1
self.query_length = 1
self.d_model = 8
self.dim_feedforward = 8
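# Static-graph check: the fused_feedforward output should match an equivalent
# graph composed of F.linear, F.dropout and F.layer_norm.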
class APITestStaticFusedFFN(unittest.TestCase):
def test_static(self):
paddle.enable_static()
default_main_program().random_seed = 42
dtype = "float32"
layer_norm_dtype = "float32"
batch_size = 1
d_model = 8
dim_feedforward = 8
x = paddle.static.data(
name='x', shape=[batch_size, d_model, dim_feedforward], dtype=dtype
)
linear1_weight = paddle.static.data(
name='linear1_weight', shape=[d_model, dim_feedforward], dtype=dtype
)
linear1_bias = paddle.static.data(
name='linear1_bias', shape=[dim_feedforward], dtype=dtype
)
linear2_weight = paddle.static.data(
name='linear2_weight', shape=[dim_feedforward, d_model], dtype=dtype
)
linear2_bias = paddle.static.data(name='linear2_bias', shape=[d_model])
ln1_scale = paddle.static.data(name='ln1_scale', shape=[d_model])
ln1_bias = paddle.static.data(name='ln1_bias', shape=[d_model])
ln2_scale = paddle.static.data(name='ln2_scale', shape=[d_model])
ln2_bias = paddle.static.data(name='ln2_bias', shape=[d_model])
fused_out = incubate_f.fused_feedforward(
x,
linear1_weight,
linear2_weight,
linear1_bias,
linear2_bias,
ln1_scale,
ln1_bias,
ln2_scale,
ln2_bias,
0.0,
0.0,
activation="relu",
pre_layer_norm=False,
)
linear1_out = F.linear(x, linear1_weight, linear1_bias)
act_out = F.relu(linear1_out)
dropout1_out = F.dropout(x=act_out, p=0.0, training=False)
linear2_out = F.linear(dropout1_out, linear2_weight, linear2_bias)
dropout2_out = x + F.dropout(x=linear2_out, p=0.0, training=False)
ln_out = F.layer_norm(
dropout2_out,
normalized_shape=list([d_model]),
weight=ln2_scale,
bias=ln2_bias,
)
exe = paddle.static.Executor(paddle.XPUPlace(0))
x_data = np.random.random(
(batch_size, d_model, dim_feedforward)
).astype(dtype)
linear1_weight_data = np.random.random(
(d_model, dim_feedforward)
).astype(dtype)
linear1_bias_data = np.zeros((dim_feedforward)).astype(dtype)
linear2_weight_data = np.random.random(
(dim_feedforward, d_model)
).astype(dtype)
linear2_bias_data = np.zeros((d_model)).astype(dtype)
ln1_scale_data = np.ones((d_model)).astype(layer_norm_dtype)
ln1_bias_data = np.zeros((d_model)).astype(layer_norm_dtype)
ln2_scale_data = np.ones((d_model)).astype(layer_norm_dtype)
ln2_bias_data = np.zeros((d_model)).astype(layer_norm_dtype)
res_list = [fused_out, ln_out]
real_res = []
for res in res_list:
fetch = exe.run(
feed={
'x': x_data,
'linear1_weight': linear1_weight_data,
'linear1_bias': linear1_bias_data,
'linear2_weight': linear2_weight_data,
'linear2_bias': linear2_bias_data,
'ln1_scale': ln1_scale_data,
'ln1_bias': ln1_bias_data,
'ln2_scale': ln2_scale_data,
'ln2_bias': ln2_bias_data,
},
fetch_list=[res],
)
real_res.append(fetch)
np.testing.assert_allclose(
real_res[0], real_res[1], rtol=1e-05, atol=0.001
)
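# Error-path checks: an invalid input dtype and invalid dropout rate/mode
# arguments should raise TypeError / ValueError.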
class TestFusedFFNOpError(unittest.TestCase):
def test_errors(self):
paddle.enable_static()
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
def test_dtype():
x = paddle.static.data(
name='x', shape=[1, 10, 10], dtype="int32"
)
linear1_weight = paddle.static.data(
name='linear1_weight', shape=[1, 10, 10], dtype="float32"
)
linear2_weight = paddle.static.data(
name='linear2_weight', shape=[1, 10, 10], dtype="float32"
)
incubate_f.fused_feedforward(x, linear1_weight, linear2_weight)
self.assertRaises(TypeError, test_dtype)
def test_dropout_rate_type():
x = paddle.static.data(
name='x1', shape=[1, 10, 10], dtype="float32"
)
linear1_weight = paddle.static.data(
name='linear1_weight1', shape=[10, 10], dtype="float32"
)
linear2_weight = paddle.static.data(
name='linear2_weight1', shape=[10, 10], dtype="float32"
)
incubate_f.fused_feedforward(
x, linear1_weight, linear2_weight, dropout1_rate="a"
)
self.assertRaises(TypeError, test_dropout_rate_type)
def test_dropout_rate_value():
x = paddle.static.data(
name='x2', shape=[1, 10, 10], dtype="float32"
)
linear1_weight = paddle.static.data(
name='linear1_weight2', shape=[10, 10], dtype="float32"
)
linear2_weight = paddle.static.data(
name='linear2_weight2', shape=[10, 10], dtype="float32"
)
incubate_f.fused_feedforward(
x, linear1_weight, linear2_weight, dropout2_rate=-1
)
self.assertRaises(ValueError, test_dropout_rate_value)
def test_dropout_mode():
x = paddle.static.data(
name='x3', shape=[1, 10, 10], dtype="float32"
)
linear1_weight = paddle.static.data(
name='linear1_weight3', shape=[10, 10], dtype="float32"
)
linear2_weight = paddle.static.data(
name='linear2_weight3', shape=[10, 10], dtype="float32"
)
incubate_f.fused_feedforward(
x, linear1_weight, linear2_weight, mode='test'
)
self.assertRaises(ValueError, test_dropout_mode)
support_types = {"float32"} # get_xpu_op_support_types('fused_feedforward')
for stype in support_types:
create_test_class(globals(), XPUTestFusedFFNOp, stype)
if __name__ == "__main__":
unittest.main()