Unverified commit eb0e4d4b, authored by jiangfan06, committed by GitHub

[XPU] modify add_layernorm_xpu kernel (#56429)

Parent commit: 29efbdb6
@@ -105,10 +105,12 @@ AddLayernormXPUPattern::AddLayernormXPUPattern(PDPattern* pattern,
                        ->assert_is_op_input("layer_norm", "Scale");
   auto norm_mean = pattern->NewNode(norm_mean_repr())
                        ->AsOutput()
-                       ->assert_is_op_output("layer_norm", "Mean");
+                       ->assert_is_op_output("layer_norm", "Mean")
+                       ->assert_has_n_outputs(0);
   auto norm_variance = pattern->NewNode(norm_variance_repr())
                            ->AsOutput()
-                           ->assert_is_op_output("layer_norm", "Variance");
+                           ->assert_is_op_output("layer_norm", "Variance")
+                           ->assert_has_n_outputs(0);
   auto norm_out = pattern->NewNode(norm_out_repr())
                       ->AsOutput()
                       ->assert_is_op_output("layer_norm", "Y");
@@ -198,6 +200,7 @@ void AddLayernormXPUFusePass::FuseAddLayernorm(ir::Graph* graph) const {
   fused_op_out_name = norm_out->Name();
   // Generate add_layernorm fused op
   framework::OpDesc fused_op_desc(block);
+  fused_op_desc.SetType("add_layernorm_xpu");
   // set attrs for fused op
   fused_op_desc.SetInput("x", {add_x->Name()});
@@ -207,9 +210,6 @@ void AddLayernormXPUFusePass::FuseAddLayernorm(ir::Graph* graph) const {
   fused_op_desc.SetAttr("epsilon", eps);
   fused_op_desc.SetAttr("begin_norm_axis", begin_norm_axis);
   fused_op_desc.SetOutput("out", {fused_op_out_name});
-  setIntermediateOut(&fused_op_desc, "mean", name_scope_);
-  setIntermediateOut(&fused_op_desc, "variance", name_scope_);
-  setIntermediateOut(&fused_op_desc, "z_add", name_scope_);
   // relink fused op
   auto* fused_op = graph->CreateOpNode(&fused_op_desc);
   IR_NODE_LINK_TO(add_x, fused_op);
@@ -217,9 +217,7 @@ void AddLayernormXPUFusePass::FuseAddLayernorm(ir::Graph* graph) const {
   IR_NODE_LINK_TO(norm_scale, fused_op);
   IR_NODE_LINK_TO(norm_bias, fused_op);
   IR_NODE_LINK_TO(fused_op, norm_out);
-  addIntermediateOut(fused_op, "mean", name_scope_, graph);
-  addIntermediateOut(fused_op, "variance", name_scope_, graph);
-  addIntermediateOut(fused_op, "z_add", name_scope_, graph);
   delete_nodes.insert({ele_add, l_norm, ele_out, norm_mean, norm_variance});
   GraphSafeRemoveNodes(graph, delete_nodes);
   found_subgraph_count++;
@@ -84,16 +84,14 @@ FastLayernormXPUPattern::FastLayernormXPUPattern(PDPattern* pattern,
                       ->AsInput()
                       ->assert_is_persistable_var()
                       ->assert_is_op_input("layer_norm", "Scale");
-  auto norm_mean =
-      pattern->NewNode(norm_mean_repr())
-          ->AsOutput()
-          ->assert_is_op_output("layer_norm", "Mean")
-          ->assert_more([](Node* node) { return node->outputs.size() == 0; });
-  auto norm_variance =
-      pattern->NewNode(norm_variance_repr())
-          ->AsOutput()
-          ->assert_is_op_output("layer_norm", "Variance")
-          ->assert_more([](Node* node) { return node->outputs.size() == 0; });
+  auto norm_mean = pattern->NewNode(norm_mean_repr())
+                       ->AsOutput()
+                       ->assert_is_op_output("layer_norm", "Mean")
+                       ->assert_has_n_outputs(0);
+  auto norm_variance = pattern->NewNode(norm_variance_repr())
+                           ->AsOutput()
+                           ->assert_is_op_output("layer_norm", "Variance")
+                           ->assert_has_n_outputs(0);
   auto norm_out = pattern->NewNode(norm_out_repr())
                       ->AsOutput()
                       ->assert_is_op_output("layer_norm", "Y");
@@ -16,7 +16,7 @@
 - op : add_layernorm_xpu
   args : (Tensor x, Tensor y, Tensor scale, Tensor bias, int begin_norm_axis, float epsilon)
-  output : Tensor(out), Tensor(mean), Tensor(variance), Tensor(z_add)
+  output : Tensor(out)
   infer_meta :
     func : AddLayernormXPUInferMeta
   kernel :
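For reference, the single remaining output is the layer-normalized sum; mean, variance, and the intermediate sum z_add were only byproducts of computing it. A plain C++ sketch of the op's semantics (not the XPU implementation), assuming x and y are already flattened to m rows of n elements:

#include <cmath>
#include <cstddef>
#include <vector>

// Reference semantics of add_layernorm_xpu: z = x + y, then LayerNorm
// over each row of n elements with per-column scale and bias.
void add_layernorm_ref(const std::vector<float>& x,
                       const std::vector<float>& y,
                       const std::vector<float>& scale,
                       const std::vector<float>& bias,
                       std::size_t m, std::size_t n, float epsilon,
                       std::vector<float>* out) {
  out->assign(m * n, 0.0f);
  for (std::size_t i = 0; i < m; ++i) {
    float sum = 0.0f, sqsum = 0.0f;
    for (std::size_t j = 0; j < n; ++j) {
      const float z = x[i * n + j] + y[i * n + j];  // the former z_add
      sum += z;
      sqsum += z * z;
    }
    const float mean = sum / n;                 // the former mean output
    const float var = sqsum / n - mean * mean;  // the former variance output
    const float rstd = 1.0f / std::sqrt(var + epsilon);
    for (std::size_t j = 0; j < n; ++j) {
      const float z = x[i * n + j] + y[i * n + j];
      (*out)[i * n + j] = (z - mean) * rstd * scale[j] + bias[j];
    }
  }
}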
@@ -98,10 +98,7 @@ void AddLayernormXPUInferMeta(const MetaTensor& x,
                               const MetaTensor& bias,
                               int begin_norm_axis,
                               float epsilon,
-                              MetaTensor* out,
-                              MetaTensor* mean,
-                              MetaTensor* variance,
-                              MetaTensor* z_add) {
+                              MetaTensor* out) {
   int axis = -1;
   auto x_dims = x.dims();
   auto y_dims = y.dims();
@@ -112,21 +109,9 @@ void AddLayernormXPUInferMeta(const MetaTensor& x,
   } else {
     out->set_dims(out_dims);
   }
-  auto layer_norm_x_mat_dims = phi::flatten_to_2d(out_dims, begin_norm_axis);
-  int64_t m = layer_norm_x_mat_dims[0];
-  int64_t n = layer_norm_x_mat_dims[1];
   out->set_dtype(x.dtype());
   out->set_layout(x.layout());
   out->share_lod(x);
-  mean->set_dims(phi::make_ddim({m}));
-  mean->set_dtype(DataType::FLOAT32);
-  mean->set_layout(x.layout());
-  variance->set_dims(phi::make_ddim({m}));
-  variance->set_dtype(DataType::FLOAT32);
-  variance->set_layout(x.layout());
-  z_add->set_dims(phi::make_ddim({m, n}));
-  z_add->set_dtype(x.dtype());
-  z_add->set_layout(x.layout());
 }

 inline int ConvOutSize(int input_size,
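Context for the deleted shape math: phi::flatten_to_2d collapses the dims before begin_norm_axis into m and the remaining dims into n, which is why mean and variance had shape {m} and z_add had shape {m, n}. A small sketch of that computation:

#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Sketch of the 2-D flattening the removed code performed via
// phi::flatten_to_2d(out_dims, begin_norm_axis).
std::pair<int64_t, int64_t> flatten_to_2d(const std::vector<int64_t>& dims,
                                          int begin_norm_axis) {
  int64_t m = 1, n = 1;
  for (int i = 0; i < begin_norm_axis; ++i) m *= dims[i];
  for (int i = begin_norm_axis; i < static_cast<int>(dims.size()); ++i)
    n *= dims[i];
  return {m, n};
}

int main() {
  auto mn = flatten_to_2d({2, 128, 768}, 2);
  assert(mn.first == 256 && mn.second == 768);  // mean/variance were {256}
  return 0;
}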
@@ -36,10 +36,7 @@ void AddLayernormXPUInferMeta(const MetaTensor& x,
                               const MetaTensor& bias,
                               int begin_norm_axis,
                               float epsilon,
-                              MetaTensor* out,
-                              MetaTensor* mean,
-                              MetaTensor* variance,
-                              MetaTensor* z_add);
+                              MetaTensor* out);

 void Conv1dXPUInferMeta(const MetaTensor& x,
                         const MetaTensor& x_max,
@@ -73,19 +73,13 @@ void AddLayernormXPUKernel(const Context& ctx,
                            const DenseTensor& bias,
                            int begin_norm_axis,
                            float epsilon,
-                           DenseTensor* out,
-                           DenseTensor* mean,
-                           DenseTensor* variance,
-                           DenseTensor* z_add) {
+                           DenseTensor* out) {
   using XPUType = typename XPUTypeTrait<T>::Type;
   auto* x_data = reinterpret_cast<const XPUType*>(x.data<T>());
   auto* y_data = reinterpret_cast<const XPUType*>(y.data<T>());
   const float* scale_data = scale.data<float>();
   const float* bias_data = bias.data<float>();
-  float* mean_data = ctx.template Alloc<float>(mean);
-  float* variance_data = ctx.template Alloc<float>(variance);
-  auto* z_add_data = reinterpret_cast<XPUType*>(ctx.template Alloc<T>(z_add));
   auto x_dims = x.dims();
   auto y_dims = y.dims();
@@ -106,10 +100,10 @@ void AddLayernormXPUKernel(const Context& ctx,
       /* float epsilon */ epsilon,
       /* const float* scale */ scale_data,
       /* const float* bias */ bias_data,
-      /* float* mean */ mean_data,
-      /* float* variance */ variance_data,
-      /* T* z_add */ z_add_data);
-  PADDLE_ENFORCE_XDNN_SUCCESS(r, "add_layernorm_xpu");
+      /* float* mean */ nullptr,
+      /* float* variance */ nullptr,
+      /* T* z_add */ nullptr);
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "add_layer_norm_fusion");
 }

 }  // namespace fusion
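The nullptr arguments rely on an optional-output convention: when a side-output pointer is null, the fused routine skips materializing it. A sketch of that convention (an assumption stated here for clarity; this is not the actual xdnn add_layer_norm_fusion signature):

#include <cstddef>

// Illustrative only: store per-row statistics if, and only if, the
// caller supplied buffers for them. Passing nullptr, as the kernel now
// does for mean/variance/z_add, skips the writes entirely.
void store_side_outputs(std::size_t row, float mu, float var,
                        float* mean /* nullable */,
                        float* variance /* nullable */) {
  if (mean != nullptr) mean[row] = mu;
  if (variance != nullptr) variance[row] = var;
}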
@@ -39,10 +39,7 @@ void FastLayerNormXPUKernel(const Context& ctx,
   const float* scale_data_fp32 = nullptr;
   const auto* scale_ptr = scale.get_ptr();
   if (scale_ptr == nullptr) {
-    float* scale_data_temp = RAII_GUARD.alloc_l3_or_gm<float>(right);
-    int r = xpu::constant<float>(ctx.x_context(), scale_data_temp, right, 1.0f);
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
-    scale_data_fp32 = scale_data_temp;
+    // no scale, do nothing
   } else if (scale_ptr->dtype() ==
              phi::CppTypeToDataType<phi::dtype::float16>::Type()) {
     float* scale_data_temp =
@@ -63,10 +60,7 @@ void FastLayerNormXPUKernel(const Context& ctx,
   const float* bias_data_fp32 = nullptr;
   const auto* bias_ptr = bias.get_ptr();
   if (bias_ptr == nullptr) {
-    float* bias_data_temp = RAII_GUARD.alloc_l3_or_gm<float>(right);
-    int r = xpu::constant<float>(ctx.x_context(), bias_data_temp, right, 0.0f);
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
-    bias_data_fp32 = bias_data_temp;
+    // no bias, do nothing
   } else if (bias_ptr->dtype() ==
              phi::CppTypeToDataType<phi::dtype::float16>::Type()) {
     float* bias_data_temp = RAII_GUARD.alloc_l3_or_gm<float>(bias_ptr->numel());
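With the constant-fill removed, a null scale or bias is now passed through as-is; this assumes the downstream layer_norm implementation treats a missing scale as 1.0 and a missing bias as 0.0. A sketch of that assumed fallback:

#include <cstddef>

// Assumed downstream behavior: a null scale acts as an all-ones vector
// and a null bias as all-zeros, replacing the old xpu::constant() fills.
inline float apply_affine(float normed, std::size_t j,
                          const float* scale /* nullable */,
                          const float* bias /* nullable */) {
  const float s = (scale != nullptr) ? scale[j] : 1.0f;
  const float b = (bias != nullptr) ? bias[j] : 0.0f;
  return normed * s + b;
}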
@@ -183,15 +183,9 @@ __global__ void fast_layer_norm_tiny_common(float epsilon,
   int ncores = core_num();
   int tid = cid * cluster_num() + cluster_id();
   int nthreads = ncores * cluster_num();
-  int64_t mstart = 0;
-  int64_t mend = 0;
-  partition(tid, nthreads, m, 1, &mstart, &mend);
-  if (mstart >= mend) {
-    return;
-  }
   float one_div_n = 1.0f / n;
-  constexpr int lm_buffer_size = 832 * sizeof(float) / sizeof(T);
+  constexpr int lm_buffer_size = 1664 * sizeof(float) / sizeof(T);
   constexpr int sm_buffer_size = 1664 * 16;
   __simd__ T xlm[lm_buffer_size];
   __simd__ __shared__ float scale_sm[sm_buffer_size];
@@ -207,7 +201,7 @@ __global__ void fast_layer_norm_tiny_common(float epsilon,
     GM2SM(bias, bias_sm, n * sizeof(float));
   }
   sync_all();
-  for (int64_t i = mstart; i < mend; i += 1) {
+  for (int64_t i = tid; i < m; i += nthreads) {
     GM2LM(x + i * n, xlm, n * sizeof(T));
     update_sum_and_squaresum<T>(xlm, n, &sum, &squaresum);
     float sample_mean = sum * one_div_n;
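The scheduling change replaces block partitioning (partition() into a contiguous [mstart, mend) range per thread) with a strided, round-robin loop over rows: thread tid handles rows tid, tid + nthreads, tid + 2*nthreads, and so on, so no up-front range computation or early return is needed. A minimal sketch of the new assignment:

#include <cstdint>
#include <vector>

// Round-robin row assignment used by the new loop. Work stays balanced
// even when m is small or not a multiple of nthreads.
std::vector<int64_t> rows_for_thread(int tid, int nthreads, int64_t m) {
  std::vector<int64_t> rows;
  for (int64_t i = tid; i < m; i += nthreads) rows.push_back(i);
  return rows;
}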
@@ -85,7 +85,7 @@ static int xpu2_wrapper(Context* ctx,
                         const float* scale,
                         const float* bias) {
   if (n <= 832) {
-    if (n % 32 == 0) {
+    if (n % 32 == 0 && n < 128) {
       xpu2::plugin::fast_layer_norm_tiny_align32<T>
           <<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
               eps, m, n, x, y, scale, bias);
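The tightened condition narrows the align32 fast path to small, 32-aligned row widths; wider rows up to 832 presumably fall through to fast_layer_norm_tiny_common, whose per-thread local-memory buffer was doubled to 1664 floats above. A sketch of the dispatch as it now reads (enum and helper names are simplified stand-ins for the wrapper's structure):

// Illustration of the wrapper's tiny-kernel selection after this change.
enum class TinyKernel { kAlign32, kCommon, kNone };

TinyKernel pick_tiny_kernel(int n) {
  if (n > 832) return TinyKernel::kNone;  // not a "tiny" row width
  if (n % 32 == 0 && n < 128) return TinyKernel::kAlign32;
  return TinyKernel::kCommon;             // assumed fallback path
}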