Unverified commit fc880209 authored by cmeng, committed by GitHub

Fuse ViT attention for Faster-RCNN on BML (#54139)

Parent 25409dcc
......@@ -200,6 +200,13 @@ if(WITH_MKLDNN)
pass_library(multi_gru_seq_fuse_pass inference DIR mkldnn)
pass_library(quant_dequant_mkldnn_pass inference DIR mkldnn)
pass_library(compute_propagate_scales_mkldnn_pass inference DIR mkldnn)
pass_library(self_attention_fuse_pass inference DIR mkldnn)
if(WITH_AVX
AND AVX512F_FOUND
AND AVX512F_FLAG)
set_target_properties(self_attention_fuse_pass
PROPERTIES COMPILE_FLAGS "-mfma ${AVX512F_FLAG}")
endif()
endif()
if(WITH_IPU)
......
......@@ -2615,6 +2615,81 @@ PDNode *patterns::VitAttention::operator()(PDNode *in) {
return reshape2_out;
}
PDNode *patterns::SelfAttention::operator()(PDNode *in) {
in->AsInput();
std::unordered_set<std::string> matmul_ops{"matmul", "matmul_v2"};
auto transpose2_0_op =
pattern->NewNode(transpose2_0_op_repr())->assert_is_op("transpose2");
auto transpose2_0_out = pattern->NewNode(transpose2_0_out_repr())
->assert_is_op_output("transpose2", "Out")
->assert_is_op_input("slice", "Input")
->AsIntermediate();
auto slice_0_op = pattern->NewNode(slice_0_op_repr())->assert_is_op("slice");
auto slice_0_out = pattern->NewNode(slice_0_out_repr())
->assert_is_op_output("slice", "Out")
->assert_is_ops_input(matmul_ops, "X")
->AsIntermediate();
auto slice_1_op = pattern->NewNode(slice_1_op_repr())->assert_is_op("slice");
auto slice_1_out = pattern->NewNode(slice_1_out_repr())
->assert_is_op_output("slice", "Out")
->assert_is_op_input("transpose2", "X")
->AsIntermediate();
auto slice_2_op = pattern->NewNode(slice_2_op_repr())->assert_is_op("slice");
auto slice_2_out = pattern->NewNode(slice_2_out_repr())
->assert_is_op_output("slice", "Out")
->assert_is_ops_input(matmul_ops, "Y")
->AsIntermediate();
auto matmul_0_op =
pattern->NewNode(matmul_0_op_repr())->assert_is_ops(matmul_ops);
auto matmul_0_out = pattern->NewNode(matmul_0_out_repr())
->assert_is_ops_output(matmul_ops, "Out")
->assert_is_op_input("transpose2", "X")
->AsIntermediate();
auto matmul_1_op =
pattern->NewNode(matmul_1_op_repr())->assert_is_ops(matmul_ops);
auto matmul_1_out = pattern->NewNode(matmul_1_out_repr())
->assert_is_ops_output(matmul_ops, "Out")
->assert_is_op_input("softmax", "X")
->AsIntermediate();
auto transpose2_1_op =
pattern->NewNode(transpose2_1_op_repr())->assert_is_op("transpose2");
auto transpose2_1_out = pattern->NewNode(transpose2_1_out_repr())
->assert_is_op_output("transpose2", "Out")
->assert_is_ops_input(matmul_ops, "Y")
->AsIntermediate();
auto softmax_op =
pattern->NewNode(softmax_op_repr())->assert_is_op("softmax");
auto softmax_out = pattern->NewNode(softmax_out_repr())
->assert_is_op_output("softmax", "Out")
->assert_is_ops_input(matmul_ops, "X")
->AsIntermediate();
auto transpose2_2_op =
pattern->NewNode(transpose2_2_op_repr())->assert_is_op("transpose2");
auto transpose2_2_out = pattern->NewNode(transpose2_2_out_repr())
->assert_is_op_output("transpose2", "Out")
->AsOutput();
transpose2_0_op->LinksFrom({in});
transpose2_0_out->LinksFrom({transpose2_0_op});
slice_0_op->LinksFrom({transpose2_0_out});
slice_0_out->LinksFrom({slice_0_op});
slice_1_op->LinksFrom({transpose2_0_out});
slice_1_out->LinksFrom({slice_1_op});
slice_2_op->LinksFrom({transpose2_0_out});
slice_2_out->LinksFrom({slice_2_op});
transpose2_1_op->LinksFrom({slice_1_out});
transpose2_1_out->LinksFrom({transpose2_1_op});
matmul_1_op->LinksFrom({slice_0_out, transpose2_1_out});
matmul_1_out->LinksFrom({matmul_1_op});
softmax_op->LinksFrom({matmul_1_out});
softmax_out->LinksFrom({softmax_op});
matmul_0_op->LinksFrom({softmax_out, slice_2_out});
matmul_0_out->LinksFrom({matmul_0_op});
transpose2_2_op->LinksFrom({matmul_0_out});
transpose2_2_out->LinksFrom({transpose2_2_op});
return transpose2_2_out;
}
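// For reference, the subgraph matched above computes standard scaled
// dot-product attention over the Q/K/V slices produced by the leading
// transpose2 (alpha is the scale attribute read from the Q*K^T matmul in
// self_attention_fuse_pass):
//   Out = transpose(softmax(alpha * Q * K^T) * V)
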
PDNode *patterns::ConvElementwiseadd2Act::operator()(
PDNode *conv_in, const std::unordered_set<std::string> &conv_act_set) {
auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
......
......@@ -1491,6 +1491,33 @@ struct VitAttention : public PatternBase {
PATTERN_DECL_NODE(reshape2_out);
};
// self-attention in ViT
struct SelfAttention : public PatternBase {
SelfAttention(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "vit_block") {}
PDNode* operator()(PDNode* in);
PATTERN_DECL_NODE(transpose2_0_op);
PATTERN_DECL_NODE(transpose2_0_out);
PATTERN_DECL_NODE(transpose2_1_op);
PATTERN_DECL_NODE(transpose2_1_out);
PATTERN_DECL_NODE(transpose2_2_op);
PATTERN_DECL_NODE(transpose2_2_out);
PATTERN_DECL_NODE(matmul_0_op);
PATTERN_DECL_NODE(matmul_0_out);
PATTERN_DECL_NODE(matmul_1_op);
PATTERN_DECL_NODE(matmul_1_out);
PATTERN_DECL_NODE(slice_0_op);
PATTERN_DECL_NODE(slice_0_out);
PATTERN_DECL_NODE(slice_1_op);
PATTERN_DECL_NODE(slice_1_out);
PATTERN_DECL_NODE(slice_2_op);
PATTERN_DECL_NODE(slice_2_out);
PATTERN_DECL_NODE(softmax_op);
PATTERN_DECL_NODE(softmax_out);
};
// Conv + ElementwiseAdd + an activation
// This pattern can further fuse the conv related ops after the conv+bn fusion.
struct ConvElementwiseaddAct : public PatternBase {
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/self_attention_fuse_pass.h"
#include <string>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/string/pretty_log.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_NODES \
GET_IR_NODE(transpose2_0_op); \
GET_IR_NODE(transpose2_0_out); \
GET_IR_NODE(slice_0_op); \
GET_IR_NODE(slice_0_out); \
GET_IR_NODE(slice_1_op); \
GET_IR_NODE(slice_1_out); \
GET_IR_NODE(slice_2_op); \
GET_IR_NODE(slice_2_out); \
GET_IR_NODE(matmul_0_op); \
GET_IR_NODE(matmul_0_out); \
GET_IR_NODE(matmul_1_op); \
GET_IR_NODE(matmul_1_out); \
GET_IR_NODE(transpose2_1_op); \
GET_IR_NODE(transpose2_1_out); \
GET_IR_NODE(softmax_op); \
GET_IR_NODE(softmax_out); \
GET_IR_NODE(transpose2_2_op); \
GET_IR_NODE(transpose2_2_out);
namespace paddle {
namespace framework {
namespace ir {
using string::PrettyLogDetail;
void SelfAttentionFusePass::ApplyImpl(ir::Graph* graph) const {
#if !defined(__AVX512F__) || !defined(PADDLE_WITH_MKLML) || \
!defined(PADDLE_WITH_MKLDNN)
LOG(WARNING) << "No-avx512 or MKL supported!";
return;
#endif
GraphPatternDetector gpd;
const std::string pattern_name = "self_attention_fuse";
FusePassBase::Init(pattern_name, graph);
// Build the ViT self-attention pattern, anchored at the input of the first transpose2.
PDNode* x = gpd.mutable_pattern()
->NewNode("x")
->assert_is_op_input("transpose2", "X")
->AsInput();
patterns::SelfAttention pattern(gpd.mutable_pattern(), pattern_name);
pattern(x);
int fusion_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_NODES;
// Build the fused self_dp_attention op that replaces the matched subgraph.
OpDesc desc(transpose2_0_op->Op()->Block());
desc.SetType("self_dp_attention");
desc.SetInput("X", {subgraph.at(x)->Name()});
desc.SetOutput("Out", {transpose2_2_out->Name()});
std::vector<int64_t> in_shape = subgraph.at(x)->Var()->GetShape();
std::vector<int64_t> shape = transpose2_0_out->Var()->GetShape();
// in_shape should be [batch_size, seq_len, 3, num_heads, head_size];
// after transpose2 the shape should be [3, batch_size, num_heads, seq_len, head_size].
if (in_shape.size() != 5 || in_shape[2] != 3 || shape.size() != 5 ||
shape[0] != 3 || shape[2] != in_shape[3]) {
LOG(WARNING) << "Self-attention shape mismatch!";
return;
}
desc.SetAttr("head_number", static_cast<int>(shape[2]));
float alpha = 1.0;
if (matmul_1_op->Op()->HasAttr("alpha"))
alpha = PADDLE_GET_CONST(float, matmul_1_op->Op()->GetAttr("alpha"));
desc.SetAttr("alpha", alpha);
// Create a new node for the fused op.
auto self_attention_node = graph->CreateOpNode(&desc);
// Link inputs and outputs.
PADDLE_ENFORCE_NE(subgraph.count(x),
0,
platform::errors::NotFound(
"Detector did not find input x of self attention."));
IR_NODE_LINK_TO(subgraph.at(x), self_attention_node); // Input
IR_NODE_LINK_TO(self_attention_node, transpose2_2_out); // Output
// Delete the unneeded nodes.
std::unordered_set<const Node*> marked_nodes({transpose2_0_op,
transpose2_0_out,
slice_0_op,
slice_0_out,
slice_1_op,
slice_1_out,
slice_2_op,
slice_2_out,
matmul_0_op,
matmul_0_out,
matmul_1_op,
matmul_1_out,
transpose2_1_op,
transpose2_1_out,
softmax_op,
softmax_out,
transpose2_2_op});
GraphSafeRemoveNodes(graph, marked_nodes);
++fusion_count;
};
gpd(graph, handler);
AddStatis(fusion_count);
if (!Has("disable_logs") || !Get<bool>("disable_logs")) {
PrettyLogDetail(
"--- fused %d self attention (of scaled_dp_attention) with %s",
fusion_count,
pattern_name);
}
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(self_attention_fuse_pass,
paddle::framework::ir::SelfAttentionFusePass);
REGISTER_PASS_CAPABILITY(self_attention_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("transpose2", 0)
.EQ("slice", 0)
.EQ("scale", 0)
.EQ("softmax", 0)
.EQ("matmul_v2", 0));
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
// Fuses the ViT self-attention structure into a single self_dp_attention op.
class Graph;
class SelfAttentionFusePass : public FusePassBase {
public:
virtual ~SelfAttentionFusePass() {}
protected:
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -367,6 +367,7 @@ void CpuPassStrategy::EnableMKLDNN() {
"fc_mkldnn_pass",
"fc_act_mkldnn_fuse_pass",
"fc_elementwise_add_mkldnn_fuse_pass", //
"self_attention_fuse_pass", //
"batch_norm_act_fuse_pass", //
"softplus_activation_onednn_fuse_pass", //
"shuffle_channel_mkldnn_detect_pass", //
......
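Since the pass is appended to CpuPassStrategy::EnableMKLDNN() above, inference picks it up whenever MKLDNN and IR optimization are enabled. A minimal usage sketch, assuming a Paddle Inference C++ build; the model file names are placeholders:

#include "paddle_inference_api.h"

int main() {
  paddle_infer::Config config;
  // Placeholder model files; replace with a real exported model.
  config.SetModel("model.pdmodel", "model.pdiparams");
  config.EnableMKLDNN();       // enables the MKLDNN pass list, including self_attention_fuse_pass
  config.SwitchIrOptim(true);  // run the IR fuse passes
  auto predictor = paddle_infer::CreatePredictor(config);
  return 0;
}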
......@@ -11,6 +11,7 @@ register_operators(
fusion_conv_inception_op
fused_fc_elementwise_layernorm_op
multihead_matmul_op
self_dp_attention_op
skip_layernorm_op
yolo_box_head_op
yolo_box_post_op
......@@ -33,6 +34,14 @@ register_operators(
# fusion_gru_op does not have CUDA kernel
op_library(fusion_gru_op)
op_library(fusion_lstm_op)
if(WITH_AVX
AND AVX512F_FOUND
AND AVX512F_FLAG
AND WITH_MKL)
op_library(self_dp_attention_op)
set_target_properties(self_dp_attention_op PROPERTIES COMPILE_FLAGS
"-mfma ${AVX512F_FLAG}")
endif()
if(WITH_XPU)
op_library(resnet_basic_block_op)
......
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include <assert.h>
#include <immintrin.h>
#include <math.h>
#include <omp.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <iostream>
#include <new>
#include <string>
#ifdef PADDLE_WITH_MKLDNN
#include "dnnl.hpp" //NOLINT
#endif
namespace paddle {
namespace operators {
template <typename T, typename Tt>
void arraycpy(T* dst, const Tt* src, int n) {
#ifdef PADDLE_WITH_MKLML
#pragma omp simd
#endif
for (int i = 0; i < n; i++) {
dst[i] = static_cast<T>(src[i]);
}
}
// [batch, tokens, 3, head_num, head_size] -> [3, batch, head_num, tokens, head_size]
// (permutation 2 0 3 1 4)
template <typename T, typename Tt>
void transpose_before_bmm1(const T* qkvBuffer,
Tt* qkvTransBuffer,
int batchSize,
int tokenSize,
int headNum,
int headSize) {
int hiddenSize = headNum * headSize;
int blocksize = tokenSize * hiddenSize; // dst buffer stride in each batch
const T* qBuffer = qkvBuffer;
const T* kBuffer = qkvBuffer + hiddenSize;
const T* vBuffer = qkvBuffer + hiddenSize * 2;
Tt* q_buffer = qkvTransBuffer;
Tt* k_buffer = qkvTransBuffer + batchSize * blocksize;
Tt* v_buffer = qkvTransBuffer + batchSize * blocksize * 2;
int bmHead = headNum;
int cols_per_bmHead = hiddenSize / headNum; // 768/12 = 64
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for collapse(3)
#endif
for (int i = 0; i < batchSize; i++) {
for (int k = 0; k < bmHead; k++) {
for (int j = 0; j < tokenSize; j++) {
const T* q_src_each_batch =
reinterpret_cast<const T*>(qBuffer) + blocksize * 3 * i;
const T* k_src_each_batch =
reinterpret_cast<const T*>(kBuffer) + blocksize * 3 * i;
const T* v_src_each_batch =
reinterpret_cast<const T*>(vBuffer) + blocksize * 3 * i;
int dst_offset_each_bmHead = k * tokenSize * cols_per_bmHead;
int src_offset_each_line = k * cols_per_bmHead;
int dst_offset_each_line = j * cols_per_bmHead;
int src_offset_each_bmHead = j * hiddenSize * 3;
Tt* q_dst_each_line = q_buffer + i * blocksize +
dst_offset_each_bmHead + dst_offset_each_line;
const T* q_src_each_line =
q_src_each_batch + src_offset_each_bmHead + src_offset_each_line;
Tt* k_dst_each_line = k_buffer + i * blocksize +
dst_offset_each_bmHead + dst_offset_each_line;
const T* k_src_each_line =
k_src_each_batch + src_offset_each_bmHead + src_offset_each_line;
Tt* v_dst_each_line = v_buffer + i * blocksize +
dst_offset_each_bmHead + dst_offset_each_line;
const T* v_src_each_line =
v_src_each_batch + src_offset_each_bmHead + src_offset_each_line;
arraycpy<Tt, T>(q_dst_each_line, q_src_each_line, cols_per_bmHead);
arraycpy<Tt, T>(k_dst_each_line, k_src_each_line, cols_per_bmHead);
arraycpy<Tt, T>(v_dst_each_line, v_src_each_line, cols_per_bmHead);
}
}
}
}
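// In index form, the copy above performs (with n in {0: Q, 1: K, 2: V}):
//   qkvTransBuffer[n][b][h][t][0:head_size] = qkvBuffer[b][t][n][h][0:head_size]
// i.e. the (2 0 3 1 4) permutation noted above, with the fused QKV split apart.
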
// [batch, head_num, tokens, head_size] -> [batch, tokens, head_num, head_size] (permutation 0 2 1 3)
template <typename T, typename Tt>
void transpose_after_bmm2(T* Buffer,
Tt* TransBuffer,
int batchSize,
int tokenSize,
int headNum,
int headSize) {
int hiddenSize = headNum * headSize;
int blocksize = tokenSize * hiddenSize; // dst buffer stride in each batch
int bmHead = headNum;
int cols_per_bmHead = hiddenSize / headNum; // 768/12 = 64
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for collapse(2)
#endif
for (int i = 0; i < batchSize; i++) {
for (int k = 0; k < tokenSize; k++) {
int src_offset_each_head = k * cols_per_bmHead;
int dst_offset_each_line = k * hiddenSize;
for (int j = 0; j < bmHead; j++) {
int src_offset_each_line = j * tokenSize * cols_per_bmHead;
int dst_offset_each_head = j * cols_per_bmHead;
Tt* q_dst_each_line = TransBuffer + dst_offset_each_head +
dst_offset_each_line + i * blocksize;
const T* q_src_each_line = Buffer + src_offset_each_line +
src_offset_each_head + i * blocksize;
arraycpy<Tt, T>(q_dst_each_line, q_src_each_line, cols_per_bmHead);
}
}
}
}
// C = A * B
// transa/transb: whether A / B needs to be transposed
void sgemm(const float* A,
const float* B,
float* C,
int m,
int n,
int k,
bool transa,
bool transb) {
#ifdef PADDLE_WITH_MKLDNN
int lda = (transa ? m : k);
int ldb = (transb ? k : n);
int ldc = n;
float alpha = 1;
float beta = 0;
char ta[] = "N";
char tb[] = "N";
if (transa) ta[0] = 'T';
if (transb) tb[0] = 'T';
dnnl_sgemm(ta[0], tb[0], m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
#else
LOG(ERROR) << "scaled_dp_atten not supported without WITH_MKL!";
#endif
}
#if defined(__AVX512F__)
// Vectorized exp approximation (Cephes-style polynomial, as used in jit code)
static inline __m512 vexp(const __m512& _x) {
__m512 p16f_1 = _mm512_set1_ps(1.0f);
__m512 p16f_half = _mm512_set1_ps(0.5f);
__m512 p16f_127 = _mm512_set1_ps(127.f);
__m512 p16f_exp_hi = _mm512_set1_ps(88.3762626647950f);
__m512 p16f_exp_lo = _mm512_set1_ps(-88.3762626647949f);
__m512 p16f_cephes_LOG2EF = _mm512_set1_ps(1.44269504088896341f);
__m512 p16f_cephes_exp_p0 = _mm512_set1_ps(1.9875691500E-4f);
__m512 p16f_cephes_exp_p1 = _mm512_set1_ps(1.3981999507E-3f);
__m512 p16f_cephes_exp_p2 = _mm512_set1_ps(8.3334519073E-3f);
__m512 p16f_cephes_exp_p3 = _mm512_set1_ps(4.1665795894E-2f);
__m512 p16f_cephes_exp_p4 = _mm512_set1_ps(1.6666665459E-1f);
__m512 p16f_cephes_exp_p5 = _mm512_set1_ps(5.0000001201E-1f);
// Clamp x.
__m512 x = _mm512_max_ps(_mm512_min_ps(_x, p16f_exp_hi), p16f_exp_lo);
// Express exp(x) as exp(m*ln(2) + r), start by extracting
// m = floor(x/ln(2) + 0.5).
__m512 m = _mm512_floor_ps(_mm512_fmadd_ps(x, p16f_cephes_LOG2EF, p16f_half));
// Get r = x - m*ln(2).
__m512 p16f_nln2 = _mm512_set1_ps(-0.6931471805599453f);
__m512 r = _mm512_fmadd_ps(m, p16f_nln2, x);
__m512 r2 = _mm512_mul_ps(r, r);
__m512 y = p16f_cephes_exp_p0;
y = _mm512_fmadd_ps(y, r, p16f_cephes_exp_p1);
y = _mm512_fmadd_ps(y, r, p16f_cephes_exp_p2);
y = _mm512_fmadd_ps(y, r, p16f_cephes_exp_p3);
y = _mm512_fmadd_ps(y, r, p16f_cephes_exp_p4);
y = _mm512_fmadd_ps(y, r, p16f_cephes_exp_p5);
y = _mm512_fmadd_ps(y, r2, r);
y = _mm512_add_ps(y, p16f_1);
// Build emm0 = 2^m.
__m512i emm0 = _mm512_cvttps_epi32(_mm512_add_ps(m, p16f_127));
emm0 = _mm512_slli_epi32(emm0, 23);
// Return 2^m * exp(r).
return _mm512_max_ps(_mm512_mul_ps(y, _mm512_castsi512_ps(emm0)), _x);
}
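// The approximation above uses the standard range reduction
//   m = floor(x * log2(e) + 1/2),  r = x - m * ln(2),  exp(x) = 2^m * exp(r),
// with exp(r) ~= 1 + r + r^2 * (p5 + p4*r + p3*r^2 + p2*r^3 + p1*r^4 + p0*r^5)
// using the p16f_cephes_exp_p0..p5 constants above.
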
// Streaming softmax over one tile: updates the running per-row max/sum and
// stores exp(AB * refac - max) / sum back into AB in place.
void softmax_sum_max(float* AB,
float* sum,
float* max,
float* pre_sum,
float* pre_max,
float refac,
int m,
int k) {
assert(k % 16 == 0);
float max_val = std::numeric_limits<float>::lowest();
__m512 vrefac = _mm512_set1_ps(refac);
for (int i = 0; i < m; ++i) {
float* buf = AB + i * k;
// max val for avoiding inf and nan
__m512 vmax = _mm512_set1_ps(max_val);
for (int off = 0; off < k; off += 16) {
int remain = k - off;
__mmask16 mask = (remain >= 16 ? 0xffff : (1 << remain) - 1);
__m512 vx = _mm512_maskz_loadu_ps(mask, buf + off);
vmax = _mm512_mask_max_ps(vmax, mask, vmax, vx);
}
float _max = _mm512_reduce_max_ps(vmax);
_max *= refac;
_max = _max > max[i] ? _max : max[i];
__m512 merr = _mm512_set1_ps(max[i] - _max);
merr = vexp(merr);
max[i] = _max;
// exp and get sum
__m512 vsum = _mm512_set1_ps(0);
vmax = _mm512_set1_ps(_max);
for (int off = 0; off < k; off += 16) {
int remain = k - off;
__mmask16 mask = (remain >= 16 ? 0xffff : (1 << remain) - 1);
__m512 vx = _mm512_maskz_loadu_ps(mask, buf + off);
vx = vexp(vx * vrefac - vmax);
_mm512_mask_storeu_ps(buf + off, mask, vx);
vsum = _mm512_mask_add_ps(vsum, mask, vsum, vx);
}
float _sum = _mm512_reduce_add_ps(vsum);
float fac = _mm512_cvtss_f32(merr);
sum[i] = sum[i] * fac + _sum;
_sum = sum[i];
// Compute exp/sum(exp) and store
__m512 vrsum = _mm512_set1_ps(1.0f / _sum);
for (int off = 0; off < k; off += 16) {
int remain = k - off;
__mmask16 mask = (remain >= 16 ? 0xffff : (1 << remain) - 1);
__m512 vx = _mm512_maskz_loadu_ps(mask, buf + off);
vx = vx * vrsum;
_mm512_mask_storeu_ps(buf + off, mask, vx);
}
}
}
void update_out_blk(float* output,
const float* exp_ABC,
float* pre_sum,
float* sum,
float* pre_max,
float* max,
int m,
int n) {
assert(n % 16 == 0);
for (int i = 0; i < m; ++i) {
const float* buf = exp_ABC + i * n;
float* outbuf = output + i * n;
__m512 merr = _mm512_set1_ps(pre_max[i] - max[i]);
merr = vexp(merr);
__m512 vfac = _mm512_set1_ps(pre_sum[i] / sum[i]);
for (int off = 0; off < n; off += 16) {
__m512 vout = _mm512_loadu_ps(outbuf + off);
__m512 vabc = _mm512_loadu_ps(buf + off);
__m512 vupt = vout * merr * vfac + vabc;
_mm512_storeu_ps(outbuf + off, vupt);
}
pre_sum[i] = sum[i];
pre_max[i] = max[i];
}
}
#endif
// Incremental (online) softmax attention over one K/V tile; the softmax axis
// is hard-coded to 1. Accumulates output += softmax-weighted C while rescaling
// the running sum/max statistics (see the update equations after the function).
void incremental_tile_attention(const float* A,
const float* B,
const float* C,
int m,
int n,
int k,
float* pre_sum,
float* sum,
float* pre_max,
float* max,
float refac,
float* AB,
float* exp_ABC,
float* output) {
sgemm(A, B, AB, m, k, n, false, true);
softmax_sum_max(AB, sum, max, pre_sum, pre_max, refac, m, k);
sgemm(AB, C, exp_ABC, m, n, k, false, false);
update_out_blk(output, exp_ABC, pre_sum, sum, pre_max, max, m, n);
}
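// The three calls above map onto the numerically stable streaming-softmax
// update for one K/V tile: with per-row scores S = refac * A*B^T and running
// statistics (max m, sum l, output O),
//   m' = max(m, max_j S_j)
//   l' = l * exp(m - m') + sum_j exp(S_j - m')
//   O' = O * exp(m - m') * l / l' + sum_j (exp(S_j - m') / l') * C_j
// softmax_sum_max leaves exp(S - m') / l' in AB, the second sgemm multiplies
// it by the value tile C, and update_out_blk rescales the previous partial
// output; subtracting the running max keeps exp() from overflowing.
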
// scaled dot-product attention: bmm1 + softmax + bmm2
void scaled_dp_attention(const float* query,
const float* key,
const float* value,
float scale,
int batch_size,
int itsize,
int otsize,
int num_head,
int head_size,
float* output) {
// output = softmax(scale * query * trans(key)) * value
// (the transpose back to [batch, tokens, head, head_size] is done by transpose_after_bmm2)
int iblk = std::min(512, itsize);
int oblk = std::min(512, otsize);
float refac = scale;
assert(itsize % iblk == 0);
assert(otsize % oblk == 0);
#ifdef PADDLE_WITH_MKLML
int nth = omp_get_max_threads();
#else
int nth = 1;
#endif
float** pre_sum;
float** sum;
float** pre_max;
float** max;
float** qk_arr;
float** exp_qkv_arr;
pre_sum = new float*[nth];
sum = new float*[nth];
pre_max = new float*[nth];
max = new float*[nth];
qk_arr = new float*[nth];
exp_qkv_arr = new float*[nth];
for (int i = 0; i < nth; ++i) {
pre_sum[i] = new float[iblk];
sum[i] = new float[iblk];
pre_max[i] = new float[iblk];
max[i] = new float[iblk];
qk_arr[i] = new float[iblk * oblk];
exp_qkv_arr[i] = new float[iblk * head_size];
}
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for collapse(3)
#endif
for (int i = 0; i < batch_size; ++i) {
for (int j = 0; j < num_head; ++j) {
for (int m = 0; m < itsize; m += iblk) {
#ifdef PADDLE_WITH_MKLML
int tid = omp_get_thread_num();
#else
int tid = 0;
#endif
int ooffset =
i * num_head * otsize * head_size + j * otsize * head_size;
const float* k = key + ooffset;
const float* v = value + ooffset;
int q_rblk = std::min(iblk, itsize - m);
int ioffset =
i * num_head * otsize * head_size + j * otsize * head_size;
const float* q = query + ioffset + m * head_size;
float* out = output + ioffset + m * head_size;
// reset out
for (int ii = 0; ii < q_rblk; ++ii) {
#ifdef PADDLE_WITH_MKLML
#pragma omp simd
#endif
for (int jj = 0; jj < head_size; ++jj) {
out[ii * head_size + jj] = 0; // reset output
}
}
// reset sum
#ifdef PADDLE_WITH_MKLML
#pragma omp simd
#endif
for (int ii = 0; ii < q_rblk; ++ii) {
pre_sum[tid][ii] = 0;
sum[tid][ii] = 0;
pre_max[tid][ii] = std::numeric_limits<float>::lowest();
max[tid][ii] = std::numeric_limits<float>::lowest();
}
//
for (int b = 0; b < otsize; b += oblk) {
int kv_rblk = std::min(oblk, otsize - b);
const float* blk_k = k + b * head_size;
const float* blk_v = v + b * head_size;
incremental_tile_attention(q,
blk_k,
blk_v,
q_rblk,
head_size,
kv_rblk,
pre_sum[tid],
sum[tid],
pre_max[tid],
max[tid],
refac,
qk_arr[tid],
exp_qkv_arr[tid],
out);
}
}
}
}
for (int i = 0; i < nth; ++i) {
delete[] pre_sum[i];
delete[] sum[i];
delete[] pre_max[i];
delete[] max[i];
delete[] qk_arr[i];
delete[] exp_qkv_arr[i];
}
delete[] pre_sum;
delete[] sum;
delete[] pre_max;
delete[] max;
delete[] qk_arr;
delete[] exp_qkv_arr;
return;
}
} // namespace operators
} // namespace paddle
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include "paddle/fluid/operators/fused/self_dp_attention_op.h"
#include "paddle/fluid/operators/fused/scaled_dp_attention.h"
namespace paddle {
namespace operators {
void SelfDPAttenOp::InferShape(framework::InferShapeContext* ctx) const {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SelfDPAtten");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "SelfDPAtten");
auto dim_input = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(dim_input.size(),
5,
platform::errors::InvalidArgument(
"The size of input X dims should be 5, "
"[batchsize, tokensize, 3, nhead, headsize] "
", but now Input X dim is:[%s] ",
dim_input));
PADDLE_ENFORCE_EQ(dim_input[4] % 16,
0,
platform::errors::InvalidArgument(
"The last dim of input X should be a multiple of 16, "
", but now the dim is:[%d] "
"Please remove self_attention_fuse_pass from the lists",
dim_input[4]));
framework::DDim out_dims(
{dim_input[0], dim_input[1], dim_input[3], dim_input[4]});
ctx->SetOutputDim("Out", out_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
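// Example: with the shapes used in the Python unit test (batch 1, 4097 tokens,
// 12 heads of size 64), X = [1, 4097, 3, 12, 64] yields Out = [1, 4097, 12, 64].
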
phi::KernelKey SelfDPAttenOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.GetPlace());
}
void SelfDPAttenOpMaker::Make() {
AddInput("X", "(LoDTensor) Input tensors of this operator.");
AddOutput("Out", "(LoDTensor) Output tensor of this operator.");
AddAttr<float>("alpha", "The scale of Out").SetDefault(1.0f);
AddAttr<int>("head_number", "The number of heads of the matrix")
.SetDefault(1);
AddComment(R"DOC(
Multihead Self-scaled-dp-Attention Operator.
)DOC");
}
template <typename T>
class SelfDPAttenKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using DeviceContext = phi::CPUContext;
auto* in = ctx.Input<Tensor>("X");
auto* out = ctx.Output<LoDTensor>("Out");
auto place = ctx.GetPlace();
auto* input_d = in->data<T>();
auto* output_d = out->mutable_data<T>(place);
float scale = static_cast<float>(ctx.Attr<float>("alpha"));
int head_number = ctx.Attr<int>("head_number");
auto input_dims = in->dims();
// in should be (batch, seq_len, 3, head_num, head_size)
// out should be (batch, seq_len, head_num, head_size)
int batch_size = input_dims[0];
int seq_len = input_dims[1];
int head_size = input_dims[4];
auto& dev_ctx = ctx.template device_context<DeviceContext>();
phi::DenseTensor temp1 =
ctx.AllocateTmpTensor<T, DeviceContext>(input_dims, dev_ctx);
float* trans_input = temp1.mutable_data<float>(place);
phi::DenseTensor temp2 =
ctx.AllocateTmpTensor<T, DeviceContext>(input_dims, dev_ctx);
float* trans_output = temp2.mutable_data<float>(place);
transpose_before_bmm1<T, float>(
input_d, trans_input, batch_size, seq_len, head_number, head_size);
float* query = trans_input;
float* key = trans_input + batch_size * head_number * seq_len * head_size;
float* value =
trans_input + batch_size * head_number * seq_len * head_size * 2;
scaled_dp_attention(query,
key,
value,
scale,
batch_size,
seq_len,
seq_len,
head_number,
head_size,
trans_output);
transpose_after_bmm2<float, T>(
trans_output, output_d, batch_size, seq_len, head_number, head_size);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(self_dp_attention,
ops::SelfDPAttenOp,
ops::SelfDPAttenOpMaker);
REGISTER_OP_KERNEL(self_dp_attention,
CPU,
phi::CPUPlace,
ops::SelfDPAttenKernel<float>,
ops::SelfDPAttenKernel<double>);
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using LoDTensor = phi::DenseTensor;
using Tensor = phi::DenseTensor;
class SelfDPAttenOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override;
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override;
};
class SelfDPAttenOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override;
};
} // namespace operators
} // namespace paddle
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
import paddle.incubate
from paddle.fluid import core
paddle.enable_static()
np.random.seed(0)
def test_self_attention_fuse_pass():
place = paddle.CPUPlace()
program = paddle.static.Program()
startup_program = paddle.static.Program()
batch_size = 1
token_size = 4097
hidden_size = 768
num_heads = 12
dtype = np.float32
with paddle.static.program_guard(program, startup_program):
x = paddle.static.data(
"x", [batch_size, token_size, hidden_size * 3], dtype=dtype
)
qkv = x.reshape(
(batch_size, token_size, 3, num_heads, hidden_size // num_heads)
).transpose((2, 0, 3, 1, 4))
q, k, v = qkv[0], qkv[1], qkv[2]
attn = q.matmul(k.transpose((0, 1, 3, 2)))
attn = paddle.nn.functional.softmax(attn, axis=-1)
out = (
(attn.matmul(v))
.transpose((0, 2, 1, 3))
.reshape((-1, token_size, hidden_size))
)
graph = core.Graph(program.desc)
core.get_pass("self_attention_fuse_pass").apply(graph)
after_program = paddle.fluid.framework.IrGraph(graph).to_program()
exe = paddle.static.Executor(place)
exe.run(startup_program)
feed = {
"x": np.random.randn(batch_size, token_size, hidden_size * 3).astype(
dtype
)
}
before_out = exe.run(program, feed=feed, fetch_list=[out.name])
after_out = exe.run(after_program, feed=feed, fetch_list=[out.name])
np.testing.assert_allclose(
before_out[0], after_out[0], rtol=1e-05, atol=0.005
)
if __name__ == '__main__':
test_self_attention_fuse_pass()