Unverified · Commit f16e1869 authored by jiangfan06, committed by GitHub

[XPU] Add fast_layernorm_xpu_fuse_pass and fast_layernorm_xpu plugin (#56269)

Parent be22021c
......@@ -284,6 +284,8 @@ if(WITH_XPU)
${XPU_PASS_DEPS})
pass_library(gather_squeeze_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(fast_where_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(fast_layernorm_xpu_fuse_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
endif()
cc_library(
......
......@@ -372,7 +372,7 @@ void LayerNormFusePass::ApplyImpl(Graph* graph) const {
// ------------------ op creation and placement ---------------------------
-  OpDesc ln_op_desc;
+  OpDesc ln_op_desc(x_mean->Op()->Block());
ln_op_desc.SetType("layer_norm");
ln_op_desc.SetInput("X", {x->Name()});
ln_op_desc.SetInput("Scale", {new_gamma_node->Name()});
......
......@@ -119,7 +119,9 @@ int DeleteIsolatedNodePass::RemoveIsolatedNodes(
for (auto* node : graph->Nodes()) {
if (node->IsOp()) {
block = node->Op()->Block();
-      break;
+      if (block != nullptr) {
+        break;
+      }
}
}
Scope& scope = graph->Get<framework::Scope>("__param_scope__");
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace phi {
class DenseTensor;
} // namespace phi
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
/*
Fuse a standalone layer_norm op into a fast_layernorm_xpu op.
For example:
graph:
x
|
layernorm
|
output
------------------------------------------------------
After the pass is applied:
x
|
fast_layernorm_xpu
|
output
*/
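// A concrete example (illustrative, using the [batch, 16, 128] input and
// begin_norm_axis = 2 from the unit test added later in this change): the
// matched layer_norm normalizes the trailing 128 elements of each row, and
// after fusion only Y survives while the Mean/Variance outputs are dropped.
// The assert_more checks below require Mean and Variance to have no
// consumers, so graphs that still read the statistics are left untouched.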
struct FastLayernormXPUPattern : public PatternBase {
FastLayernormXPUPattern(PDPattern* pattern, const std::string& name_scope);
// declare operator node's name
PATTERN_DECL_NODE(l_norm);
// declare variable node's name
PATTERN_DECL_NODE(norm_in);
PATTERN_DECL_NODE(norm_bias);
PATTERN_DECL_NODE(norm_scale);
PATTERN_DECL_NODE(norm_mean);
PATTERN_DECL_NODE(norm_variance);
PATTERN_DECL_NODE(norm_out);
};
FastLayernormXPUPattern::FastLayernormXPUPattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern, name_scope, name_scope) {
auto l_norm = pattern->NewNode(l_norm_repr())->assert_is_op("layer_norm");
auto norm_in = pattern->NewNode(norm_in_repr())
->AsInput()
->assert_is_op_input("layer_norm", "X");
auto norm_bias = pattern->NewNode(norm_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Bias");
auto norm_scale = pattern->NewNode(norm_scale_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("layer_norm", "Scale");
auto norm_mean =
pattern->NewNode(norm_mean_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Mean")
->assert_more([](Node* node) { return node->outputs.size() == 0; });
auto norm_variance =
pattern->NewNode(norm_variance_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Variance")
->assert_more([](Node* node) { return node->outputs.size() == 0; });
auto norm_out = pattern->NewNode(norm_out_repr())
->AsOutput()
->assert_is_op_output("layer_norm", "Y");
l_norm->LinksFrom({norm_in, norm_bias, norm_scale})
.LinksTo({norm_out, norm_mean, norm_variance});
}
} // namespace patterns
class FastLayernormXPUFusePass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;
private:
void FuseFastLayernorm(ir::Graph* graph) const;
const std::string name_scope_{"fast_layernorm_xpu_fuse_pass"};
};
void FastLayernormXPUFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
Init(name_scope_, graph);
FuseFastLayernorm(graph);
}
void FastLayernormXPUFusePass::FuseFastLayernorm(ir::Graph* graph) const {
GraphPatternDetector gpd;
patterns::FastLayernormXPUPattern pattern(gpd.mutable_pattern(), name_scope_);
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle FastLayernormXPUFusePass";
// declare operator node's name
GET_IR_NODE(l_norm);
// declare variable node's name
GET_IR_NODE(norm_in);
GET_IR_NODE(norm_bias);
GET_IR_NODE(norm_scale);
GET_IR_NODE(norm_mean);
GET_IR_NODE(norm_variance);
GET_IR_NODE(norm_out);
auto* block = l_norm->Op()->Block();
auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(
scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));
// delete useless node
std::unordered_set<const Node*> delete_nodes;
float eps = PADDLE_GET_CONST(float, l_norm->Op()->GetAttr("epsilon"));
int begin_norm_axis =
PADDLE_GET_CONST(int, l_norm->Op()->GetAttr("begin_norm_axis"));
// Generate fast_layernorm_xpu op
framework::OpDesc fused_op_desc(block);
fused_op_desc.SetType("fast_layernorm_xpu");
fused_op_desc.SetInput("x", {norm_in->Name()});
fused_op_desc.SetInput("scale", {norm_scale->Name()});
fused_op_desc.SetInput("bias", {norm_bias->Name()});
fused_op_desc.SetAttr("epsilon", eps);
fused_op_desc.SetAttr("begin_norm_axis", begin_norm_axis);
fused_op_desc.SetOutput("out", {norm_out->Name()});
auto* fused_op = graph->CreateOpNode(&fused_op_desc);
IR_NODE_LINK_TO(norm_in, fused_op);
IR_NODE_LINK_TO(norm_scale, fused_op);
IR_NODE_LINK_TO(norm_bias, fused_op);
IR_NODE_LINK_TO(fused_op, norm_out);
delete_nodes.insert({l_norm, norm_mean, norm_variance});
GraphSafeRemoveNodes(graph, delete_nodes);
found_subgraph_count++;
};
gpd(graph, handler);
AddStatis(found_subgraph_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(fast_layernorm_xpu_fuse_pass,
paddle::framework::ir::FastLayernormXPUFusePass);
REGISTER_PASS_CAPABILITY(fast_layernorm_xpu_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().EQ(
"layer_norm", 0));
......@@ -547,6 +547,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"conv2d_transpose_xpu_fuse_pass",
"add_activation_xpu_fuse_pass",
"add_layernorm_xpu_fuse_pass",
"fast_layernorm_xpu_fuse_pass",
"yolo_box_xpu_fuse_pass",
"fast_where_xpu_fuse_pass",
"link_xpu_op_max_pass",
......
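Because the new pass is appended to XpuPassStrategy, it runs by default during XPU inference. The snippet below is a minimal sketch of driving it through the standard paddle_infer C++ API; the model path is a placeholder, and the explicit AppendPass call is only needed when a custom pass list is used instead of the default strategy.

```cpp
// Minimal sketch, assuming the stock paddle_infer C++ API and header layout;
// "model_dir" is a placeholder path.
#include "paddle_inference_api.h"

int main() {
  paddle_infer::Config config;
  config.SetModel("model_dir");  // placeholder: directory with model + params
  config.EnableXpu();            // selects XpuPassStrategy, which now includes
                                 // fast_layernorm_xpu_fuse_pass by default
  // Only needed when assembling a custom pass list:
  // config.pass_builder()->AppendPass("fast_layernorm_xpu_fuse_pass");
  auto predictor = paddle_infer::CreatePredictor(config);
  return 0;
}
```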
......@@ -63,6 +63,15 @@
data_type: tables
optional : mask, seq_lod, max_seq_len
- op : fast_layernorm_xpu
args : (Tensor x, Tensor scale, Tensor bias, int begin_norm_axis, float epsilon)
output : Tensor(out)
infer_meta :
func : FastLayernormXPUInferMeta
kernel :
func : fast_layernorm_xpu
data_type : x
- op : fast_where_xpu
args : (Tensor condition, Tensor x, Tensor y)
output : Tensor(out)
......
......@@ -306,6 +306,8 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet({phi::DataType::INT32,
phi::DataType::FLOAT32,
phi::DataType::FLOAT16})},
{"fast_layernorm_xpu",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"fc_xpu",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"fill",
......
......@@ -825,4 +825,15 @@ void FastWhereXPUInferMeta(const MetaTensor& condition,
out->set_dtype(x.dtype());
}
void FastLayernormXPUInferMeta(const MetaTensor& x,
const MetaTensor& scale,
const MetaTensor& bias,
int begin_norm_axis,
float epsilon,
MetaTensor* out) {
out->set_dims(x.dims());
out->set_dtype(x.dtype());
out->set_layout(x.layout());
}
} // namespace phi
......@@ -197,4 +197,11 @@ void FastWhereXPUInferMeta(const MetaTensor& condition,
const MetaTensor& y,
MetaTensor* out);
void FastLayernormXPUInferMeta(const MetaTensor& x,
const MetaTensor& scale,
const MetaTensor& bias,
int begin_norm_axis,
float epsilon,
MetaTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
namespace fusion {
template <typename T, typename Context>
void FastLayerNormXPUKernel(const Context& ctx,
const DenseTensor& x,
const paddle::optional<DenseTensor>& scale,
const paddle::optional<DenseTensor>& bias,
int begin_norm_axis,
float epsilon,
DenseTensor* out) {
using XPUType = typename XPUTypeTrait<T>::Type;
const auto& x_dims = x.dims();
auto matrix_dim = phi::flatten_to_2d(x_dims, begin_norm_axis);
int left = static_cast<int>(matrix_dim[0]);
int right = static_cast<int>(matrix_dim[1]);
const auto* x_data = x.data<T>();
xpu::ctx_guard RAII_GUARD(ctx.x_context());
// scale
const float* scale_data_fp32 = nullptr;
const auto* scale_ptr = scale.get_ptr();
if (scale_ptr == nullptr) {
float* scale_data_temp = RAII_GUARD.alloc_l3_or_gm<float>(right);
int r = xpu::constant<float>(ctx.x_context(), scale_data_temp, right, 1.0f);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
scale_data_fp32 = scale_data_temp;
} else if (scale_ptr->dtype() ==
phi::CppTypeToDataType<phi::dtype::float16>::Type()) {
float* scale_data_temp =
RAII_GUARD.alloc_l3_or_gm<float>(scale_ptr->numel());
int r = xpu::cast<XPUType, float>(
ctx.x_context(),
reinterpret_cast<const XPUType*>(scale_ptr->data<T>()),
scale_data_temp,
scale_ptr->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
scale_data_fp32 = scale_data_temp;
} else {
// no need to cast
scale_data_fp32 = scale_ptr->data<float>();
}
// bias
const float* bias_data_fp32 = nullptr;
const auto* bias_ptr = bias.get_ptr();
if (bias_ptr == nullptr) {
float* bias_data_temp = RAII_GUARD.alloc_l3_or_gm<float>(right);
int r = xpu::constant<float>(ctx.x_context(), bias_data_temp, right, 0.0f);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
bias_data_fp32 = bias_data_temp;
} else if (bias_ptr->dtype() ==
phi::CppTypeToDataType<phi::dtype::float16>::Type()) {
float* bias_data_temp = RAII_GUARD.alloc_l3_or_gm<float>(bias_ptr->numel());
int r = xpu::cast<XPUType, float>(
ctx.x_context(),
reinterpret_cast<const XPUType*>(bias_ptr->data<T>()),
bias_data_temp,
bias_ptr->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
bias_data_fp32 = bias_data_temp;
} else {
// no need to cast
bias_data_fp32 = bias_ptr->data<float>();
}
auto* out_data = ctx.template Alloc<T>(out);
#ifdef PADDLE_WITH_XPU_PLUGIN
int r = xpu::plugin::fast_layer_norm(ctx.x_context(),
reinterpret_cast<const XPUType*>(x_data),
reinterpret_cast<XPUType*>(out_data),
left,
right,
epsilon,
scale_data_fp32,
bias_data_fp32);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "fast_layer_norm");
#else
// int layer_norm(Context* ctx, const T* x, T* y, int64_t m, int64_t n, float
// eps, const float* scale, const float* bias, float* mean, float* var, bool
// is_rstd = false);
int r = xpu::layer_norm(ctx.x_context(),
reinterpret_cast<const XPUType*>(x_data),
reinterpret_cast<XPUType*>(out_data),
left,
right,
epsilon,
scale_data_fp32,
bias_data_fp32,
nullptr,
nullptr);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "layer_norm");
#endif
}
} // namespace fusion
} // namespace phi
PD_REGISTER_KERNEL(fast_layernorm_xpu,
XPU,
ALL_LAYOUT,
phi::fusion::FastLayerNormXPUKernel,
float,
phi::dtype::float16) {}
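The kernel flattens x to a 2D view at begin_norm_axis, so `left` is the number of rows to normalize and `right` is the row length. The standalone sketch below spells out that reshape arithmetic as plain loops; it is an assumption about phi::flatten_to_2d's semantics written for readability, not the library routine itself.

```cpp
// Sketch of the left/right split used in FastLayerNormXPUKernel above.
#include <cstdint>
#include <utility>
#include <vector>

std::pair<int64_t, int64_t> split_at_norm_axis(const std::vector<int64_t>& dims,
                                               int begin_norm_axis) {
  int64_t left = 1;   // rows: product of dims before begin_norm_axis
  int64_t right = 1;  // row length: product of the remaining dims
  for (int i = 0; i < static_cast<int>(dims.size()); ++i) {
    (i < begin_norm_axis ? left : right) *= dims[i];
  }
  return {left, right};
}

// Example: dims = {4, 16, 128}, begin_norm_axis = 2
//   -> left = 4 * 16 = 64 rows, right = 128 elements normalized per row.
```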
......@@ -66,6 +66,15 @@ DLL_EXPORT int take_along_axis(Context* ctx,
const std::vector<int64_t>& idxshape,
int64_t axis);
template <typename T>
DLL_EXPORT int fast_layer_norm(Context* ctx,
const T* x,
T* y,
int64_t m,
int64_t n,
float eps,
const float* scale,
const float* bias);
} // namespace plugin
} // namespace api
} // namespace xpu
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* copyright (C) 2022 KUNLUNXIN, Inc
*/
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
namespace xpu2 {
namespace plugin {
static inline __device__ float sum16(const float* ptr) {
float s0 = ptr[0] + ptr[8];
float s1 = ptr[1] + ptr[9];
float s2 = ptr[2] + ptr[10];
float s3 = ptr[3] + ptr[11];
float s4 = ptr[4] + ptr[12];
float s5 = ptr[5] + ptr[13];
float s6 = ptr[6] + ptr[14];
float s7 = ptr[7] + ptr[15];
s0 = s0 + s1;
s2 = s2 + s3;
s4 = s4 + s5;
s6 = s6 + s7;
s0 = s0 + s2;
s4 = s4 + s6;
return s0 + s4;
}
template <typename T>
static __device__ void update_sum_and_squaresum(T* a,
int size,
float* sum,
float* squaresum) {
__simd__ float sum_buf[16];
__simd__ float squaresum_buf[16];
float32x16_t al;
float32x16_t ah;
int rounddown_size = rounddown32(size - 1);
unsigned int mask = -1;
if ((size % 32) != 0) {
mask = ~(-1 << (size % 32));
}
vload2_lm_mz(a + rounddown_size, al, ah, mask);
float32x16_t vsum = vvadd_float32x16(al, ah);
al = vvmul_float32x16(al, al);
ah = vvmul_float32x16(ah, ah);
float32x16_t vsquaresum = vvadd_float32x16(al, ah);
for (int i = 0; i < rounddown_size; i += 32) {
vload2_lm(a + i, al, ah);
vsum = vvadd_float32x16(vsum, al);
vsum = vvadd_float32x16(vsum, ah);
al = vvmul_float32x16(al, al);
ah = vvmul_float32x16(ah, ah);
vsquaresum = vvadd_float32x16(vsquaresum, al);
vsquaresum = vvadd_float32x16(vsquaresum, ah);
}
vstore_lm_float32x16(sum_buf, vsum);
vstore_lm_float32x16(squaresum_buf, vsquaresum);
mfence_lm();
*sum = sum16(sum_buf);
*squaresum = sum16(squaresum_buf);
}
template <typename T>
static __device__ void vector_scale_and_bias_align32(
T* a,
int size,
float mean,
float var,
_shared_ptr_ const float* scale_sm,
_shared_ptr_ const float* bias_sm,
bool do_scale_bias) {
float32x16_t al;
float32x16_t ah;
float32x16_t bl;
float32x16_t bh;
mean = 0.0f - mean;
if (do_scale_bias) {
// ((a + b) - mean) * var * scale + bias
for (int i = 0; i < size; i += 32) {
vload2_lm(a + i, al, ah);
vload2_sm(scale_sm + i, bl, bh);
al = svadd_float32x16(mean, al);
ah = svadd_float32x16(mean, ah);
al = svmul_float32x16(var, al);
ah = svmul_float32x16(var, ah);
al = vvmul_float32x16(bl, al);
ah = vvmul_float32x16(bh, ah);
vload2_sm(bias_sm + i, bl, bh);
al = vvadd_float32x16(bl, al);
ah = vvadd_float32x16(bh, ah);
vstore2_lm(a + i, al, ah);
}
} else {
// ((a + b) - mean) * var
for (int i = 0; i < size; i += 32) {
vload2_lm(a + i, al, ah);
al = svadd_float32x16(mean, al);
ah = svadd_float32x16(mean, ah);
al = svmul_float32x16(var, al);
ah = svmul_float32x16(var, ah);
vstore2_lm(a + i, al, ah);
}
}
mfence_lm();
}
template <typename T>
__global__ void fast_layer_norm_tiny_align32(float epsilon,
int64_t m,
int64_t n,
const T* x,
T* y,
const float* scale,
const float* bias) {
int cid = core_id();
int ncores = core_num();
int tid = cid * cluster_num() + cluster_id();
int nthreads = ncores * cluster_num();
int64_t mstart = 0;
int64_t mend = 0;
partition(tid, nthreads, m, 1, &mstart, &mend);
if (mstart >= mend) {
return;
}
float one_div_n = 1.0f / n;
constexpr int lm_buffer_size = 1664 * sizeof(float) / sizeof(T);
constexpr int sm_buffer_size = 1664 * 16;
__simd__ T xlm[lm_buffer_size];
__simd__ __shared__ float scale_sm[sm_buffer_size];
__simd__ __shared__ float bias_sm[sm_buffer_size];
int block_cnt = lm_buffer_size / n;
float sum = 0.0f;
float squaresum = 0.0f;
bool do_scale_bias = false;
if (scale != nullptr && bias != nullptr) {
do_scale_bias = true;
}
if (cid == 0 && do_scale_bias) {
GM2SM_ASYNC(scale, scale_sm, n * sizeof(float));
GM2SM(bias, bias_sm, n * sizeof(float));
}
sync_all();
for (int64_t i = mstart; i < mend; i += block_cnt) {
int readlen = min((mend - i) * n, block_cnt * n);
GM2LM(x + i * n, xlm, readlen * sizeof(T));
for (int64_t j = 0; j < readlen; j += n) {
update_sum_and_squaresum<T>(xlm + j, n, &sum, &squaresum);
float sample_mean = sum * one_div_n;
float sample_var = squaresum * one_div_n - sample_mean * sample_mean;
float rstd = 1.0f / sqrt(sample_var + epsilon);
vector_scale_and_bias_align32<T>(
xlm + j, n, sample_mean, rstd, scale_sm, bias_sm, do_scale_bias);
}
LM2GM(xlm, y + i * n, readlen * sizeof(T));
}
}
template <typename T>
__global__ void fast_layer_norm_tiny_common(float epsilon,
int64_t m,
int64_t n,
const T* x,
T* y,
const float* scale,
const float* bias) {
int cid = core_id();
int ncores = core_num();
int tid = cid * cluster_num() + cluster_id();
int nthreads = ncores * cluster_num();
int64_t mstart = 0;
int64_t mend = 0;
partition(tid, nthreads, m, 1, &mstart, &mend);
if (mstart >= mend) {
return;
}
float one_div_n = 1.0f / n;
constexpr int lm_buffer_size = 832 * sizeof(float) / sizeof(T);
constexpr int sm_buffer_size = 1664 * 16;
__simd__ T xlm[lm_buffer_size];
__simd__ __shared__ float scale_sm[sm_buffer_size];
__simd__ __shared__ float bias_sm[sm_buffer_size];
float sum = 0.0f;
float squaresum = 0.0f;
bool do_scale_bias = false;
if (scale != nullptr && bias != nullptr) {
do_scale_bias = true;
}
if (cid == 0 && do_scale_bias) {
GM2SM_ASYNC(scale, scale_sm, n * sizeof(float));
GM2SM(bias, bias_sm, n * sizeof(float));
}
sync_all();
for (int64_t i = mstart; i < mend; i += 1) {
GM2LM(x + i * n, xlm, n * sizeof(T));
update_sum_and_squaresum<T>(xlm, n, &sum, &squaresum);
float sample_mean = sum * one_div_n;
float sample_var = squaresum * one_div_n - sample_mean * sample_mean;
float rstd = 1.0f / sqrt(sample_var + epsilon);
vector_scale_and_bias_align32<T>(
xlm, n, sample_mean, rstd, scale_sm, bias_sm, do_scale_bias);
LM2GM(xlm, y + i * n, n * sizeof(T));
}
}
#define _XPU_DEF__FAST_LAYER_NORM_TINY_(DTYPE) \
template __global__ void fast_layer_norm_tiny_common<DTYPE>( \
float epsilon, \
int64_t m, \
int64_t n, \
const DTYPE* x, \
DTYPE* y, \
const float* scale, \
const float* bias); \
template __global__ void fast_layer_norm_tiny_align32<DTYPE>( \
float epsilon, \
int64_t m, \
int64_t n, \
const DTYPE* x, \
DTYPE* y, \
const float* scale, \
const float* bias);
_XPU_DEF__FAST_LAYER_NORM_TINY_(float16);
_XPU_DEF__FAST_LAYER_NORM_TINY_(float);
} // namespace plugin
} // namespace xpu2
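The SIMD helpers above accumulate a per-row sum and sum of squares, derive the mean and variance as E[x] and E[x^2] - E[x]^2, and then apply (x - mean) * rstd * scale + bias in 32-element chunks. The scalar reference below is a sketch of that per-row computation for readability only; it is not device code, but it matches the math of the cpu_wrapper added further down.

```cpp
// Scalar sketch of one row of fast_layer_norm: the same arithmetic as the
// SIMD kernels, without local/shared-memory staging or 32-lane vector ops.
#include <cmath>

void fast_layer_norm_row(float* row, int n, float eps,
                         const float* scale, const float* bias) {
  float sum = 0.0f;
  float squaresum = 0.0f;
  for (int i = 0; i < n; ++i) {
    sum += row[i];
    squaresum += row[i] * row[i];
  }
  const float mean = sum / n;
  const float var = squaresum / n - mean * mean;  // E[x^2] - E[x]^2
  const float rstd = 1.0f / std::sqrt(var + eps);
  const bool do_scale_bias = (scale != nullptr && bias != nullptr);
  for (int i = 0; i < n; ++i) {
    const float v = (row[i] - mean) * rstd;
    row[i] = do_scale_bias ? v * scale[i] + bias[i] : v;
  }
}
```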
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* copyright (C) 2022 KUNLUNXIN, Inc
*/
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace xpu2 {
namespace plugin {
template <typename T>
__attribute__((global)) void fast_layer_norm_tiny_common(float epsilon,
int64_t m,
int64_t n,
const T* x,
T* y,
const float* scale,
const float* bias);
template <typename T>
__attribute__((global)) void fast_layer_norm_tiny_align32(float epsilon,
int64_t m,
int64_t n,
const T* x,
T* y,
const float* scale,
const float* bias);
} // namespace plugin
} // namespace xpu2
namespace baidu {
namespace xpu {
namespace api {
namespace plugin {
template <typename T>
static int cpu_wrapper(Context* ctx,
const T* x,
T* y,
int64_t m,
int64_t n,
float eps,
const float* scale,
const float* bias) {
for (int64_t i = 0; i < m; i++) {
float sum = 0.0f;
float square_sum = 0.0f;
for (int64_t j = 0; j < n; j++) {
float v = static_cast<float>(x[i * n + j]);
sum += v;
square_sum += v * v;
}
float mean_value = sum / n;
float var_value = square_sum / n - mean_value * mean_value;
float rstd = 1.0f / std::sqrt(var_value + eps);
for (int64_t j = 0; j < n; j++) {
float v = static_cast<float>(x[i * n + j]);
float scale_value = ((scale == nullptr) ? 1.0f : scale[j]);
float bias_value = ((bias == nullptr) ? 0.0f : bias[j]);
float out = (v - mean_value) * rstd * scale_value + bias_value;
y[i * n + j] = static_cast<T>(out);
}
}
return SUCCESS;
}
template <typename T>
static int xpu2_wrapper(Context* ctx,
const T* x,
T* y,
int64_t m,
int64_t n,
float eps,
const float* scale,
const float* bias) {
if (n <= 832) {
if (n % 32 == 0) {
xpu2::plugin::fast_layer_norm_tiny_align32<T>
<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
eps, m, n, x, y, scale, bias);
} else {
xpu2::plugin::fast_layer_norm_tiny_common<T>
<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
eps, m, n, x, y, scale, bias);
}
} else {
return layer_norm(ctx, x, y, m, n, eps, scale, bias, nullptr, nullptr);
}
return SUCCESS;
}
template <typename T>
int fast_layer_norm(Context* ctx,
const T* x,
T* y,
int64_t m,
int64_t n,
float eps,
const float* scale,
const float* bias) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T1(ctx, "fast_layer_norm", T);
WRAPPER_DUMP_PARAM5(ctx, x, y, m, n, eps);
WRAPPER_DUMP_PARAM2(ctx, scale, bias);
WRAPPER_DUMP(ctx);
int64_t xylen = -1;
WRAPPER_CHECK_SHAPE(ctx, &xylen, {m, n});
WRAPPER_CHECK_2PTRS(ctx, T, xylen, x, y);
WRAPPER_ASSERT_GE(ctx, eps, 0);
WRAPPER_CHECK_PTR_OR_NULL(ctx, float, n, scale);
WRAPPER_CHECK_PTR_OR_NULL(ctx, float, n, bias);
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper<T>(ctx, x, y, m, n, eps, scale, bias);
}
if (ctx->dev().type() == api::kXPU2) {
return xpu2_wrapper<T>(ctx, x, y, m, n, eps, scale, bias);
}
WRAPPER_UNIMPLEMENTED(ctx);
}
template int fast_layer_norm(Context*,
const float*,
float*,
int64_t,
int64_t,
float,
const float*,
const float*);
template int fast_layer_norm(Context*,
const float16*,
float16*,
int64_t,
int64_t,
float,
const float*,
const float*);
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu
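The xpu2_wrapper routes short rows (n <= 832) to the tiny plugin kernels, preferring the align32 variant when the row length is a multiple of 32, and defers longer rows to the stock xpu::layer_norm. The helper below is a documentation sketch of that branch, not part of the plugin; the 832 threshold and the modulo-32 test are copied from xpu2_wrapper.

```cpp
// Sketch of the dispatch rule in xpu2_wrapper above.
#include <cstdint>

enum class FastLayerNormPath { kTinyAlign32, kTinyCommon, kFallbackLayerNorm };

inline FastLayerNormPath select_fast_layer_norm_path(int64_t n) {
  if (n <= 832) {
    return (n % 32 == 0) ? FastLayerNormPath::kTinyAlign32
                         : FastLayerNormPath::kTinyCommon;
  }
  return FastLayerNormPath::kFallbackLayerNorm;
}
```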
......@@ -64,10 +64,6 @@ void TakeAlongAxisKernel(const Context& dev_ctx,
using XPUType = typename XPUTypeTrait<T>::Type;
int r = XPU_SUCCESS;
#ifndef PADDLE_WITH_XPU_PLUGIN
LOG(WARNING) << "Add -DWITH_XPU_PLUGIN=ON to build "
"xpu::plugin::take_along_axis(), or use "
"xpu::gather_element() instead, which leads low performance "
"in some cases.";
if (index_type == DataType::INT32) {
r = xpu::gather_element<XPUType, int>(
dev_ctx.x_context(),
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import hypothesis.strategies as st
import numpy as np
from auto_scan_test import PassAutoScanTest
from program_config import OpConfig, ProgramConfig, TensorConfig
class TestFastLayernormXPUFusePass(PassAutoScanTest):
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_xpu=True)
yield config, ["fast_layernorm_xpu"], (1e-3, 1e-3)
def sample_program_config(self, draw):
batch_size = draw(st.integers(min_value=1, max_value=50))
x_shape = [batch_size, 16, 128]
y_shape = x_shape
axis = -1
epsilon = draw(st.floats(min_value=0.0000001, max_value=0.001))
begin_norm_axis = 2
layer_norm_op = OpConfig(
"layer_norm",
inputs={
"X": ["x"],
"Scale": ["layer_norm_scale"],
"Bias": ["layer_norm_bias"],
},
outputs={
"Y": ["layer_norm_out"],
"Mean": ["layer_norm_mean"],
"Variance": ["layer_norm_var"],
},
begin_norm_axis=begin_norm_axis,
epsilon=epsilon,
)
mini_graph = [layer_norm_op]
program_config = ProgramConfig(
ops=mini_graph,
weights={
"layer_norm_scale": TensorConfig(shape=[x_shape[2]]),
"layer_norm_bias": TensorConfig(shape=[x_shape[2]]),
},
inputs={
"x": TensorConfig(shape=x_shape),
},
outputs=mini_graph[-1].outputs["Y"],
)
return program_config
def test(self):
self.run_and_statis(
quant=False,
max_examples=25,
passes=["fast_layernorm_xpu_fuse_pass"],
)
if __name__ == "__main__":
np.random.seed(200)
unittest.main()