Unverified commit d881d690 authored by RichardWooSJTU, committed by GitHub

add fused token prune op and plugin (#44281)

* add fused token prune op and plugin
Parent d2e59e15
......@@ -2089,6 +2089,7 @@ USE_TRT_CONVERTER(top_k)
USE_TRT_CONVERTER(top_k_v2)
USE_TRT_CONVERTER(squeeze2)
USE_TRT_CONVERTER(unsqueeze2)
USE_TRT_CONVERTER(fused_token_prune)
#if PADDLE_WITH_CUSPARSELT && IS_TRT_VERSION_GE(8000)
USE_TRT_CONVERTER(sparse_fc)
USE_TRT_CONVERTER(sparse_multihead_matmul)
......
......@@ -68,7 +68,8 @@ list(
c_allreduce_op.cc
top_k_op.cc
squeeze2_op.cc
unsqueeze2_op.cc)
unsqueeze2_op.cc
fused_token_prune_op.cc)
if(CUSPARSELT_FOUND AND ${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 8)
list(APPEND CONVERT_FILES sparse_fc_op.cc sparse_multihead_matmul_op.cc)
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
class FusedTokenPruneOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope,
bool test_mode) override {
framework::OpDesc op_desc(op, nullptr);
nvinfer1::ILayer* layer = nullptr;
auto* Attn = engine_->GetITensor(op_desc.Input("Attn").front());
auto* X = engine_->GetITensor(op_desc.Input("X").front());
auto* Mask = engine_->GetITensor(op_desc.Input("Mask").front());
auto* NewMask = engine_->GetITensor(op_desc.Input("NewMask").front());
bool keep_first_token =
op_desc.HasAttr("keep_first_token")
? BOOST_GET_CONST(bool, op_desc.GetAttr("keep_first_token"))
: true;
bool keep_order = op_desc.HasAttr("keep_order")
? BOOST_GET_CONST(bool, op_desc.GetAttr("keep_order"))
: false;
std::vector<nvinfer1::ITensor*> itensors = {Attn, X, Mask, NewMask};
auto output_name = op_desc.Output("SlimmedX")[0];
auto out_inds_name = op_desc.Output("CLSInds")[0];
if (engine_->with_dynamic_shape()) {
#if IS_TRT_VERSION_GE(6000)
bool with_fp16 =
engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
if (engine_->precision() == AnalysisConfig::Precision::kInt8) {
with_fp16 = true;
}
plugin::FusedTokenPrunePluginDynamic* plugin =
new plugin::FusedTokenPrunePluginDynamic(
with_fp16, keep_first_token, keep_order);
layer = engine_->AddDynamicPlugin(itensors.data(), 4, plugin);
#else
PADDLE_THROW(platform::errors::Fatal(
"You are running the TRT Dynamic Shape mode, need to confirm that "
"your TRT version is no less than 6.0"));
#endif
} else {
PADDLE_THROW(platform::errors::Fatal(
"You are running the Ernie(Bert) model in static shape mode, which "
"is not supported for the time being.\n"
"You can use the config.SetTRTDynamicShapeInfo(...) interface to set "
"the shape information to run the dynamic shape mode."));
}
RreplenishLayerAndOutput(
layer, "fused_token_prune", {output_name, out_inds_name}, test_mode);
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(fused_token_prune, FusedTokenPruneOpConverter);
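Note: the converter above only runs in TRT dynamic shape mode, and its error message points users at config.SetTRTDynamicShapeInfo(...). Below is a minimal sketch of supplying the dynamic shape ranges from the Python inference API; the model files, tensor names, and shapes are illustrative assumptions and must match the actual exported model.

import paddle.inference as paddle_infer

# Hypothetical model files; replace with the real exported Ernie/Bert model.
config = paddle_infer.Config("model/model.pdmodel", "model/model.pdiparams")
config.enable_use_gpu(100, 0)
config.enable_tensorrt_engine(workspace_size=1 << 30,
                              max_batch_size=4,
                              min_subgraph_size=3,
                              precision_mode=paddle_infer.PrecisionType.Half,
                              use_static=False,
                              use_calib_mode=False)
# Dynamic shape info is required for fused_token_prune; shapes are illustrative.
min_shape = {"attn": [1, 12, 64, 64], "x": [1, 64, 768],
             "mask": [1, 12, 64, 64], "new_mask": [1, 12, 32, 32]}
max_shape = {"attn": [4, 12, 64, 64], "x": [4, 64, 768],
             "mask": [4, 12, 64, 64], "new_mask": [4, 12, 32, 32]}
config.set_trt_dynamic_shape_info(min_shape, max_shape, max_shape)
predictor = paddle_infer.create_predictor(config)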
......@@ -275,7 +275,8 @@ struct SimpleOpTypeSetTeller : public Teller {
"recover_padding",
"remove_padding",
"squeeze2",
"unsqueeze2"};
"unsqueeze2",
"fused_token_prune"};
};
bool OpTeller::Tell(const framework::ir::Node* node,
......
......@@ -29,7 +29,8 @@ list(
remove_padding_plugin.cu
recover_padding_plugin.cu
c_allreduce_op_plugin.cu
preln_residual_bias_plugin.cu)
preln_residual_bias_plugin.cu
fused_token_prune_op_plugin.cu)
if(CUSPARSELT_FOUND AND ${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 8)
list(APPEND TRT_FILES spmm_plugin.cu)
......@@ -44,3 +45,10 @@ nv_test(
test_split_plugin
SRCS test_split_plugin.cc
DEPS paddle_framework ${GLOB_OPERATOR_DEPS} tensorrt_plugin)
if(NOT WIN32)
nv_test(
test_fused_token_prune_plugin
SRCS test_fused_token_prune_plugin.cc
DEPS paddle_framework ${GLOB_OPERATOR_DEPS} tensorrt_plugin)
endif()
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include "cub/cub.cuh"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.h"
#include "paddle/fluid/operators/fused_token_prune_op.cu.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
#if IS_TRT_VERSION_GE(6000)
template <typename T>
__global__ void ElementwiseMask(const T* a,
const T* b,
T* res,
int num_elements) {
auto tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid >= num_elements) return;
const T zero = 0;
res[tid] = b[tid] >= zero ? a[tid] : zero;
}
template <typename T>
__global__ void FillZero(T* data, int len) {
auto tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid >= len) return;
const T zero = 0;
data[tid] = zero;
}
__global__ void FillIndex(int32_t* indices, int num_raws, int num_cols) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid >= num_raws * num_cols) return;
int col = tid % num_cols;
indices[tid] = col;
}
template <typename T>
__global__ void MaximumFirst(T* mat, int num_raws, int num_cols, T max_value) {
auto raw = blockIdx.x * blockDim.x + threadIdx.x;
if (raw >= num_raws) return;
mat[raw * num_cols] = max_value;
}
__global__ void FillOffsets(int* offsets, int num_raws, int num_cols) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid > num_raws) return;
offsets[tid] = tid * num_cols;
}
template <typename T>
__global__ void Slice(
const T* src, T* dst, int num_raws, int src_num_cols, int dst_num_cols) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid >= num_raws * dst_num_cols) return;
int raw = tid / dst_num_cols;
int col = tid % dst_num_cols;
dst[tid] = src[raw * src_num_cols + col];
}
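// ReduceSum2 accumulates the masked attention scores over the head and
// query-row dimensions: each block sums one column of one head's
// [max_seq_len, max_seq_len] attention map via a shared-memory tree reduction,
// and the per-head partial sums are combined across blocks with atomicAdd,
// yielding one score per token in dst with shape [bsz, max_seq_len].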
template <typename T>
__global__ void ReduceSum2(
const T* src, T* dst, int bsz, int nb_head, int max_seq_len) {
int tid = threadIdx.x;
int bid = blockIdx.x;
int num_blocks_per_head = ((max_seq_len / blockDim.x) * max_seq_len);
int batch = bid / (nb_head * num_blocks_per_head);
int col = bid % max_seq_len;
int head = (bid / num_blocks_per_head) % nb_head;
extern __shared__ T res_float[];
res_float[tid] =
src[batch * (nb_head * max_seq_len * max_seq_len) +
head * (max_seq_len * max_seq_len) + col + tid * max_seq_len];
__syncthreads();
for (int offset = blockDim.x >> 1; offset > 0; offset >>= 1) {
if (tid < offset) {
res_float[tid] += res_float[tid + offset];
}
__syncthreads();
if (offset % 2 == 1 && tid == offset - 2) {
res_float[tid] += res_float[tid + 1];
}
}
if (tid == 0) {
auto* dst_addr = dst + batch * max_seq_len + col;
atomicAdd(dst_addr, res_float[0]);
}
}
template <>
__global__ void ReduceSum2<half>(
const half* src, half* dst, int bsz, int nb_head, int max_seq_len) {
int tid = threadIdx.x;
int bid = blockIdx.x;
int num_blocks_per_head = ((max_seq_len / blockDim.x) * max_seq_len);
int batch = bid / (nb_head * num_blocks_per_head);
int col = bid % max_seq_len;
int head = (bid / num_blocks_per_head) % nb_head;
extern __shared__ half res_half[];
res_half[tid] =
src[batch * (nb_head * max_seq_len * max_seq_len) +
head * (max_seq_len * max_seq_len) + col + tid * max_seq_len];
__syncthreads();
for (int offset = blockDim.x >> 1; offset > 0; offset >>= 1) {
if (tid < offset) {
res_half[tid] += res_half[tid + offset];
}
__syncthreads();
if (offset % 2 == 1 && tid == offset - 2) {
res_half[tid] += res_half[tid + 1];
}
__syncthreads();
}
if (tid == 0) {
platform::fastAtomicAdd<platform::float16>(
reinterpret_cast<platform::float16*>(dst),
static_cast<size_t>(batch * max_seq_len + col),
static_cast<size_t>(bsz * max_seq_len),
static_cast<platform::float16>(res_half[0]));
}
}
template <typename T>
__global__ void TakeAlongAxis(const T* src,
T* dst,
int32_t* indices,
int num_raws,
int src_num_cols,
int dst_num_cols,
int num_elements) {
int tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid >= num_raws * dst_num_cols) return;
int raw = tid / dst_num_cols;
int col = tid % dst_num_cols;
for (int i = 0; i < num_elements; ++i) {
dst[tid * num_elements + i] =
*(src + (raw * src_num_cols + indices[tid]) * num_elements + i);
}
}
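// Output 0 (SlimmedX) keeps X's dims except that the sequence dimension is
// replaced by the pruned length taken from NewMask's third dim; output 1
// (CLSInds) has shape [bsz, slimmed_seq_len] and holds the kept token indices.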
nvinfer1::DimsExprs FusedTokenPrunePluginDynamic::getOutputDimensions(
int output_index,
const nvinfer1::DimsExprs* inputs,
int nb_inputs,
nvinfer1::IExprBuilder& expr_builder) TRT_NOEXCEPT {
auto x_dims = inputs[1], new_mask_dims = inputs[3];
if (output_index == 0) {
nvinfer1::DimsExprs ret = x_dims;
ret.d[1] = new_mask_dims.d[2];
return ret;
} else {
nvinfer1::DimsExprs ret;
ret.nbDims = 2;
ret.d[0] = new_mask_dims.d[0];
ret.d[1] = new_mask_dims.d[2];
return ret;
}
}
bool FusedTokenPrunePluginDynamic::supportsFormatCombination(
int pos,
const nvinfer1::PluginTensorDesc* in_out,
int nb_inputs,
int nb_outputs) TRT_NOEXCEPT {
PADDLE_ENFORCE_NOT_NULL(
in_out,
platform::errors::InvalidArgument(
"The input of swish plugin shoule not be nullptr."));
PADDLE_ENFORCE_LT(
pos,
nb_inputs + nb_outputs,
platform::errors::InvalidArgument("The pos(%d) should be less than the "
"num(%d) of the input and the output.",
pos,
nb_inputs + nb_outputs));
const nvinfer1::PluginTensorDesc& in = in_out[pos];
if (pos == 0) {
if (with_fp16_) {
#ifdef TRT_PLUGIN_FP16_AVALIABLE
return (in.type == nvinfer1::DataType::kFLOAT ||
in.type == nvinfer1::DataType::kHALF) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
#else
return (in.type == nvinfer1::DataType::kFLOAT) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
#endif
} else {
return (in.type == nvinfer1::DataType::kFLOAT) &&
(in.format == nvinfer1::TensorFormat::kLINEAR);
}
} else if (pos <= 4) {
const nvinfer1::PluginTensorDesc& prev = in_out[pos - 1];
return in.type == prev.type && in.format == prev.format;
} else {
const nvinfer1::PluginTensorDesc& prev = in_out[pos - 1];
return in.type == nvinfer1::DataType::kINT32 && in.format == prev.format;
}
}
nvinfer1::DataType FusedTokenPrunePluginDynamic::getOutputDataType(
int index,
const nvinfer1::DataType* input_types,
int nb_inputs) const TRT_NOEXCEPT {
  if (index == 0) {
    return input_types[1];
  }
  // index == 1: CLSInds is always int32 (also avoids a missing-return warning).
  return nvinfer1::DataType::kINT32;
}
size_t FusedTokenPrunePluginDynamic::getWorkspaceSize(
const nvinfer1::PluginTensorDesc* inputs,
int nb_inputs,
const nvinfer1::PluginTensorDesc* outputs,
int nb_outputs) const TRT_NOEXCEPT {
auto attn_dims = inputs[0].dims;
auto x_dims = inputs[1].dims;
auto new_mask_dims = inputs[3].dims;
auto bsz = attn_dims.d[0], nb_head = attn_dims.d[1],
max_seq_len = attn_dims.d[2];
int slimmed_x_len = new_mask_dims.d[2];
int total = bsz * nb_head * max_seq_len * max_seq_len;
  // The layout matches enqueueImpl; buffers are sized with float so the same
  // budget also covers the fp16 path.
  size_t size = total * sizeof(float);            // masked attention (attn_tmp)
  size += bsz * max_seq_len * sizeof(float);      // accumulated scores (attn_accu)
  size += bsz * max_seq_len * sizeof(int32_t);    // token indices
  size += bsz * max_seq_len * sizeof(float);      // sorted scores
  size += bsz * max_seq_len * sizeof(int32_t);    // sorted indices
  size += (bsz + 1) * sizeof(int);                // segment offsets for cub sort
  size += bsz * slimmed_x_len * sizeof(int32_t);  // sliced (kept) indices
return size;
}
template <typename T>
int FusedTokenPrunePluginDynamic::enqueueImpl(
const nvinfer1::PluginTensorDesc* input_desc,
const nvinfer1::PluginTensorDesc* output_desc,
const void* const* inputs,
void* const* outputs,
void* workspace_ptr,
cudaStream_t stream,
int device_id,
T max_value) {
// Dims
auto attn_dims = input_desc[0].dims;
auto x_dims = input_desc[1].dims;
auto new_mask_dims = input_desc[3].dims;
auto bsz = attn_dims.d[0], nb_head = attn_dims.d[1],
max_seq_len = attn_dims.d[2];
auto c = x_dims.d[2];
auto slimmed_x_len = new_mask_dims.d[2];
// Inputs
const T* attn_data = static_cast<const T*>(inputs[0]);
const T* x_data = static_cast<const T*>(inputs[1]);
const T* mask_data = static_cast<const T*>(inputs[2]);
// Outputs
T* output_data = static_cast<T*>(outputs[0]);
int32_t* output_indices_data = static_cast<int32_t*>(outputs[1]);
int total = bsz * nb_head * max_seq_len * max_seq_len;
int block = operators::ComputeBlockSize(max_seq_len);
int grid = operators::CeilDivide(total, block);
// Workspace for intermediate variable
char* workspace = static_cast<char*>(workspace_ptr);
T* attn_tmp_data = reinterpret_cast<T*>(workspace);
size_t offset = total * sizeof(T);
T* attn_accu_data = reinterpret_cast<T*>(workspace + offset);
offset += bsz * max_seq_len * sizeof(T);
int32_t* attn_accu_indices_data =
reinterpret_cast<int32_t*>(workspace + offset);
offset += bsz * max_seq_len * sizeof(int32_t);
T* sort_attn_accu_data = reinterpret_cast<T*>(workspace + offset);
offset += bsz * max_seq_len * sizeof(T);
int32_t* sort_attn_accu_indices_data =
reinterpret_cast<int32_t*>(workspace + offset);
offset += bsz * max_seq_len * sizeof(int32_t);
int* offsets_data = reinterpret_cast<int*>(workspace + offset);
offset += (bsz + 1) * sizeof(int);
int32_t* slimmed_sort_attn_accu_indices_data =
reinterpret_cast<int32_t*>(workspace + offset);
// 1. Filter attn by mask
ElementwiseMask<T>
<<<grid, block, 0, stream>>>(attn_data, mask_data, attn_tmp_data, total);
total = bsz * max_seq_len;
block = operators::ComputeBlockSize(max_seq_len);
grid = operators::CeilDivide(total, block);
FillZero<T><<<grid, block, 0, stream>>>(attn_accu_data, total);
// 2. Reduce sum
total = bsz * nb_head * max_seq_len * max_seq_len;
  // If max_seq_len > 1024 it must be a power of two, so halving keeps the
  // block size an integral divisor of max_seq_len.
  int block_tmp = max_seq_len;
  while (block_tmp > 1024) block_tmp /= 2;
  block = block_tmp;
grid = operators::CeilDivide(total, block);
ReduceSum2<T><<<grid, block, block * sizeof(T), stream>>>(
attn_tmp_data, attn_accu_data, bsz, nb_head, max_seq_len);
// 3. Prepare token indices
total = bsz * max_seq_len;
block = operators::ComputeBlockSize(max_seq_len);
grid = operators::CeilDivide(total, block);
FillIndex<<<grid, block, 0, stream>>>(
attn_accu_indices_data, bsz, max_seq_len);
// 4. Sort token indices by attn
if (keep_first_token_) {
MaximumFirst<T>
<<<bsz, 1, 0, stream>>>(attn_accu_data, bsz, max_seq_len, max_value);
}
size_t temp_storage_bytes = -1;
int num_items = bsz * max_seq_len;
int num_segments = bsz;
FillOffsets<<<bsz + 1, 1, 0, stream>>>(offsets_data, bsz, max_seq_len);
PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceSegmentedRadixSort::SortPairsDescending(
nullptr,
temp_storage_bytes,
attn_accu_data,
sort_attn_accu_data,
attn_accu_indices_data,
sort_attn_accu_indices_data,
num_items,
num_segments,
offsets_data,
offsets_data + 1,
0,
sizeof(T) * 8,
stream));
int64_t temp_size = temp_storage_bytes;
framework::Tensor temp_storage;
auto* temp_storage_data = temp_storage.mutable_data<uint8_t>(
{temp_size}, platform::CUDAPlace(device_id));
PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceSegmentedRadixSort::SortPairsDescending(
temp_storage_data,
temp_storage_bytes,
attn_accu_data,
sort_attn_accu_data,
attn_accu_indices_data,
sort_attn_accu_indices_data,
num_items,
num_segments,
offsets_data,
offsets_data + 1,
0,
sizeof(T) * 8,
stream));
// 5. Slice
total = bsz * slimmed_x_len;
block = operators::ComputeBlockSize(slimmed_x_len);
grid = operators::CeilDivide(total, block);
Slice<int32_t>
<<<grid, block, 0, stream>>>(sort_attn_accu_indices_data,
slimmed_sort_attn_accu_indices_data,
bsz,
max_seq_len,
slimmed_x_len);
if (keep_order_) {
// 6. reorder
num_items = bsz * slimmed_x_len;
FillOffsets<<<bsz + 1, 1, 0, stream>>>(offsets_data, bsz, slimmed_x_len);
temp_storage_bytes = -1;
PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceSegmentedRadixSort::SortKeys(
nullptr,
temp_storage_bytes,
slimmed_sort_attn_accu_indices_data,
output_indices_data,
num_items,
num_segments,
offsets_data,
offsets_data + 1,
0,
sizeof(int32_t) * 8,
stream));
temp_size = temp_storage_bytes;
temp_storage.Resize({temp_size});
temp_storage_data =
temp_storage.mutable_data<uint8_t>(platform::CUDAPlace(device_id));
PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceSegmentedRadixSort::SortKeys(
temp_storage_data,
temp_storage_bytes,
slimmed_sort_attn_accu_indices_data,
output_indices_data,
num_items,
num_segments,
offsets_data,
offsets_data + 1,
0,
sizeof(int32_t) * 8,
stream));
TakeAlongAxis<T><<<grid, block, 0, stream>>>(x_data,
output_data,
output_indices_data,
bsz,
max_seq_len,
slimmed_x_len,
c);
} else {
PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpy(output_indices_data,
slimmed_sort_attn_accu_indices_data,
bsz * slimmed_x_len * sizeof(int32_t),
cudaMemcpyDeviceToDevice));
TakeAlongAxis<T>
<<<grid, block, 0, stream>>>(x_data,
output_data,
slimmed_sort_attn_accu_indices_data,
bsz,
max_seq_len,
slimmed_x_len,
c);
}
return cudaGetLastError() != cudaSuccess;
}
int FusedTokenPrunePluginDynamic::enqueue(
const nvinfer1::PluginTensorDesc* input_desc,
const nvinfer1::PluginTensorDesc* output_desc,
const void* const* inputs,
void* const* outputs,
void* workspace,
cudaStream_t stream) TRT_NOEXCEPT {
auto input_type = input_desc[0].type;
auto attn_dims = input_desc[0].dims;
auto bsz = attn_dims.d[0], nb_head = attn_dims.d[1],
max_seq_len = attn_dims.d[2];
int device_id;
cudaGetDevice(&device_id);
if (input_type == nvinfer1::DataType::kFLOAT) {
VLOG(1) << "TRT Plugin DataType selected. FusedTokenPrune-->fp32";
float max = std::numeric_limits<float>::max();
return enqueueImpl<float>(input_desc,
output_desc,
inputs,
outputs,
workspace,
stream,
device_id,
max);
} else if (input_type == nvinfer1::DataType::kHALF) {
#ifdef TRT_PLUGIN_FP16_AVALIABLE
VLOG(1) << "TRT Plugin DataType selected. FusedTokenPrune-->fp16";
half max = 65504.0;
return enqueueImpl<half>(input_desc,
output_desc,
inputs,
outputs,
workspace,
stream,
device_id,
max);
#else
PADDLE_THROW(platform::errors::Fatal(
"The Ernie(Bert) TensorRT Plugin should be "
"complied with CUDA version >= 10.0 when running with fp16. "
"Please recomplie it or try to use fp32 by set "
"config.SetTRTDynamicShapeInfo(min_input_shape, "
"max_input_shape, opt_input_shape, true"));
#endif
} else {
PADDLE_THROW(
platform::errors::Fatal("The FusedTokenPrune TRT Plugin's input type "
"should be float or half."));
}
}
#endif
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
#if IS_TRT_VERSION_GE(6000)
class FusedTokenPrunePluginDynamic : public DynamicPluginTensorRT {
public:
explicit FusedTokenPrunePluginDynamic(bool with_fp16,
bool keep_first_token,
bool keep_order)
: keep_first_token_(keep_first_token), keep_order_(keep_order) {
with_fp16_ = with_fp16;
}
FusedTokenPrunePluginDynamic(void const* serial_data, size_t serial_length) {
DeserializeValue(&serial_data, &serial_length, &with_fp16_);
DeserializeValue(&serial_data, &serial_length, &keep_first_token_);
DeserializeValue(&serial_data, &serial_length, &keep_order_);
}
nvinfer1::IPluginV2DynamicExt* clone() const TRT_NOEXCEPT override {
return new FusedTokenPrunePluginDynamic(
with_fp16_, keep_first_token_, keep_order_);
}
const char* getPluginType() const TRT_NOEXCEPT override {
return "fused_token_prune_plugin_dynamic";
}
int getNbOutputs() const TRT_NOEXCEPT override { return 2; }
int initialize() TRT_NOEXCEPT override { return 0; }
size_t getSerializationSize() const TRT_NOEXCEPT override {
return SerializedSize(with_fp16_) + SerializedSize(keep_first_token_) +
SerializedSize(keep_order_);
}
void serialize(void* buffer) const TRT_NOEXCEPT override {
SerializeValue(&buffer, with_fp16_);
SerializeValue(&buffer, keep_first_token_);
SerializeValue(&buffer, keep_order_);
}
nvinfer1::DimsExprs getOutputDimensions(
int output_index,
const nvinfer1::DimsExprs* inputs,
int nb_inputs,
nvinfer1::IExprBuilder& expr_builder) // NOLINT
TRT_NOEXCEPT override;
bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc* in_out,
int nb_inputs,
int nb_outputs) TRT_NOEXCEPT override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
int nb_inputs,
const nvinfer1::DynamicPluginTensorDesc* out,
int nb_outputs) TRT_NOEXCEPT override {}
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
int nb_inputs,
const nvinfer1::PluginTensorDesc* outputs,
int nb_outputs) const TRT_NOEXCEPT override;
int enqueue(const nvinfer1::PluginTensorDesc* input_desc,
const nvinfer1::PluginTensorDesc* output_desc,
const void* const* inputs,
void* const* outputs,
void* workspace,
cudaStream_t stream) TRT_NOEXCEPT override;
nvinfer1::DataType getOutputDataType(int index,
const nvinfer1::DataType* input_types,
int nb_inputs) const
TRT_NOEXCEPT override;
void destroy() TRT_NOEXCEPT override { delete this; }
private:
template <typename T>
int enqueueImpl(const nvinfer1::PluginTensorDesc* input_desc,
const nvinfer1::PluginTensorDesc* output_desc,
const void* const* inputs,
void* const* outputs,
void* workspace,
cudaStream_t stream,
int device_id,
T max_value);
bool keep_first_token_;
bool keep_order_;
};
class FusedTokenPrunePluginDynamicCreator : public nvinfer1::IPluginCreator {
public:
FusedTokenPrunePluginDynamicCreator() {}
const char* getPluginName() const TRT_NOEXCEPT override {
return "fused_token_prune_plugin_dynamic";
}
const char* getPluginVersion() const TRT_NOEXCEPT override { return "1"; }
const nvinfer1::PluginFieldCollection* getFieldNames() TRT_NOEXCEPT override {
return &field_collection_;
}
nvinfer1::IPluginV2* createPlugin(const char* name,
const nvinfer1::PluginFieldCollection* fc)
TRT_NOEXCEPT override {
return nullptr;
}
nvinfer1::IPluginV2* deserializePlugin(const char* name,
const void* serial_data,
size_t serial_length)
TRT_NOEXCEPT override {
auto plugin = new FusedTokenPrunePluginDynamic(serial_data, serial_length);
return plugin;
}
void setPluginNamespace(const char* lib_namespace) TRT_NOEXCEPT override {
plugin_namespace_ = lib_namespace;
}
const char* getPluginNamespace() const TRT_NOEXCEPT override {
return plugin_namespace_.c_str();
}
private:
std::string plugin_namespace_;
std::string plugin_name_;
nvinfer1::PluginFieldCollection field_collection_;
std::vector<nvinfer1::PluginField> plugin_attributes_;
};
REGISTER_TRT_PLUGIN_V2(FusedTokenPrunePluginDynamicCreator);
#endif
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
TEST(fused_token_prune_op_plugin, test_plugin) {
FusedTokenPrunePluginDynamic plugin(
true, /*keep_first_token*/ false, /*keep_order*/ true);
plugin.configurePlugin(nullptr, 4, nullptr, 2);
plugin.initialize();
plugin.getPluginType();
plugin.getNbOutputs();
auto clone_plugin = plugin.clone();
clone_plugin->destroy();
size_t buf_size = plugin.getSerializationSize();
std::vector<char> buf(buf_size);
plugin.serialize(buf.data());
}
TEST(fused_token_prune_op_plugin, test_plugin_creater) {
FusedTokenPrunePluginDynamicCreator creator;
creator.getFieldNames();
creator.createPlugin("test", nullptr);
creator.setPluginNamespace("test");
}
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
......@@ -22,6 +22,7 @@ limitations under the License. */
#if PADDLE_WITH_CUSPARSELT && IS_TRT_VERSION_GE(8000)
#include "paddle/fluid/inference/tensorrt/plugin/spmm_plugin.h"
#endif
#include "paddle/fluid/inference/tensorrt/plugin/fused_token_prune_op_plugin.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/common/float16.h"
......@@ -195,6 +196,197 @@ TEST_F(TensorRTDynamicEngineTest, test_spmm) {
return;
}
class TensorRTDynamicTestFusedTokenPrune : public ::testing::Test {
protected:
void SetUp() override {
ctx_ = new platform::CUDADeviceContext(platform::CUDAPlace(0));
ctx_->SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(platform::CUDAPlace(0), ctx_->stream())
.get());
ctx_->SetHostAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
ctx_->SetZeroAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetZeroAllocator(platform::CUDAPlace(0))
.get());
ctx_->SetPinnedAllocator(
paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CUDAPinnedPlace())
.get());
ctx_->PartialInitWithAllocator();
std::map<std::string, std::vector<int>> min_input_shape = {
{"attn", {4, 1, 4, 4}},
{"x", {4, 4, 1}},
{"mask", {4, 1, 4, 4}},
{"new_mask", {4, 1, 2, 2}}};
std::map<std::string, std::vector<int>> max_input_shape = {
{"attn", {4, 1, 4, 4}},
{"x", {4, 4, 1}},
{"mask", {4, 1, 4, 4}},
{"new_mask", {4, 1, 2, 2}}};
std::map<std::string, std::vector<int>> optim_input_shape = {
{"attn", {4, 1, 4, 4}},
{"x", {4, 4, 1}},
{"mask", {4, 1, 4, 4}},
{"new_mask", {4, 1, 2, 2}}};
engine_ = new TensorRTEngine(16,
1 << 10,
AnalysisConfig::Precision::kHalf,
nullptr,
0,
min_input_shape,
max_input_shape,
optim_input_shape,
false,
phi::DataType::FLOAT32,
NaiveLogger::Global());
engine_->InitNetwork();
}
void TearDown() override {
if (engine_) {
delete engine_;
engine_ = nullptr;
}
}
void PrepareInputOutput(const std::vector<std::vector<float16>> inputs,
std::vector<std::vector<int>> output_shapes) {
LOG(INFO) << "PrepareInputOutput";
int num_inputs = inputs.size();
int num_outputs = output_shapes.size();
inputs_.resize(num_inputs);
outputs_.resize(num_outputs);
for (int i = 0; i < num_inputs; ++i) {
paddle::framework::TensorFromVector(inputs[i], *ctx_, &inputs_[i]);
}
for (int i = 0; i < num_outputs; ++i) {
outputs_[i].Resize(phi::make_ddim(output_shapes[i]));
}
}
void GetOutput(std::vector<float> &slimmed_x, // NOLINT
std::vector<int32_t> &cls_inds) { // NOLINT
paddle::framework::TensorToVector(outputs_[0], *ctx_, &slimmed_x);
paddle::framework::TensorToVector(outputs_[1], *ctx_, &cls_inds);
}
protected:
std::vector<framework::Tensor> inputs_;
std::vector<framework::Tensor> outputs_;
TensorRTEngine *engine_;
platform::CUDADeviceContext *ctx_;
};
TEST_F(TensorRTDynamicTestFusedTokenPrune, test_fused_token_prune) {
#if IS_TRT_VERSION_GE(8000)
auto *attn = engine_->DeclareInput(
"attn", nvinfer1::DataType::kHALF, nvinfer1::Dims4{-1, 1, 4, 4});
auto *x = engine_->DeclareInput(
"x", nvinfer1::DataType::kHALF, nvinfer1::Dims3{-1, 4, 1});
auto *mask = engine_->DeclareInput(
"mask", nvinfer1::DataType::kHALF, nvinfer1::Dims4{-1, 1, 4, 4});
auto *new_mask = engine_->DeclareInput(
"new_mask", nvinfer1::DataType::kHALF, nvinfer1::Dims4{-1, 1, 2, 2});
plugin::FusedTokenPrunePluginDynamic *plugin =
new plugin::FusedTokenPrunePluginDynamic(
true, /*keep_first_token*/ false, /*keep_order*/ true);
std::vector<nvinfer1::ITensor *> itensors = {attn, x, mask, new_mask};
auto *layer = engine_->AddDynamicPlugin(itensors.data(), 4, plugin);
PADDLE_ENFORCE_NOT_NULL(layer,
platform::errors::InvalidArgument(
"TRT fused_token_prune layer building failed."));
std::vector<std::string> output_tensor_names{"out_slimmed_x", "out_cls_inds"};
for (size_t i = 0; i < 2; i++) {
layer->getOutput(i)->setName(output_tensor_names[i].c_str());
engine_->DeclareOutput(layer, i, output_tensor_names[i]);
}
engine_->FreezeNetwork();
ASSERT_EQ(engine_->engine()->getNbBindings(), 6);
LOG(INFO) << "create input";
std::vector<float16> attn_v(64);
for (int i = 0; i < 4; ++i) {
for (int j = 0; j < 4; ++j) {
for (int k = 0; k < 4; ++k) {
attn_v[i * 16 + j * 4 + k] = k;
}
}
}
std::vector<float16> x_v(16);
for (int i = 0; i < 4; ++i) {
for (int j = 0; j < 4; ++j) {
x_v[i * 4 + j] = 1;
}
}
std::vector<float16> mask_v(64);
for (int i = 0; i < 4; ++i) {
for (int j = 0; j < 4; ++j) {
for (int k = 0; k < 4; ++k) {
mask_v[i * 16 + j * 4 + k] = 1;
}
}
}
std::vector<float16> new_mask_v(16);
for (int i = 0; i < 4; ++i) {
for (int j = 0; j < 2; ++j) {
for (int k = 0; k < 2; ++k) {
new_mask_v[i * 4 + j * 2 + k] = 1;
}
}
}
LOG(INFO) << "create output";
std::vector<int> out_slimmed_x_shape{4, 2, 1};
std::vector<int> out_cls_ins_shape{4, 2};
PrepareInputOutput({attn_v, x_v, mask_v, new_mask_v},
{out_slimmed_x_shape, out_cls_ins_shape});
auto *attn_gpu_data = inputs_[0].mutable_data<float16>(ctx_->GetPlace());
auto *x_gpu_data = inputs_[1].mutable_data<float16>(ctx_->GetPlace());
auto *mask_gpu_data = inputs_[2].mutable_data<float16>(ctx_->GetPlace());
auto *new_mask_gpu_data = inputs_[3].mutable_data<float16>(ctx_->GetPlace());
auto *slimmed_x_gpu_data = outputs_[0].mutable_data<float>(ctx_->GetPlace());
auto *cls_inds_gpu_data = outputs_[1].mutable_data<int32_t>(ctx_->GetPlace());
LOG(INFO) << "create buffers";
std::vector<void *> buffers(6);
buffers[0] = reinterpret_cast<void *>(attn_gpu_data);
buffers[1] = reinterpret_cast<void *>(x_gpu_data);
buffers[2] = reinterpret_cast<void *>(mask_gpu_data);
buffers[3] = reinterpret_cast<void *>(new_mask_gpu_data);
buffers[4] = reinterpret_cast<void *>(slimmed_x_gpu_data);
buffers[5] = reinterpret_cast<void *>(cls_inds_gpu_data);
LOG(INFO) << "Execute";
engine_->Execute(4, &buffers, ctx_->stream());
std::vector<float> slimmed_x_v;
std::vector<int32_t> cls_inds_v;
LOG(INFO) << "GetOutput";
GetOutput(slimmed_x_v, cls_inds_v);
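  // With attn[b][0][j][k] = k and an all-ones mask, each token's accumulated
  // score is 4 * k, so the two highest-scoring tokens are columns 3 and 2;
  // keep_order = true re-sorts the kept indices ascending, giving {2, 3} per batch.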
ASSERT_EQ(cls_inds_v[0], 2);
ASSERT_EQ(cls_inds_v[1], 3);
ASSERT_EQ(cls_inds_v[2], 2);
ASSERT_EQ(cls_inds_v[3], 3);
ASSERT_EQ(cls_inds_v[4], 2);
ASSERT_EQ(cls_inds_v[5], 3);
ASSERT_EQ(cls_inds_v[6], 2);
ASSERT_EQ(cls_inds_v[7], 3);
LOG(INFO) << "finish";
#endif
}
} // namespace tensorrt
} // namespace inference
} // namespace paddle
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace operators {
using framework::Tensor;
class FusedTokenPruneOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Attn",
"(Tensor)"
"The input of fused_token_prune op, whose shape should be [bsz, "
"num_head, max_seq_len, max_seq_len] and dtype should be "
"float32/float64,"
"Attn is attention scores of input sequences which will be used "
"to sort another input tensor: X's indices so that "
"some elements of X with lower attention score will not be "
"considered after this op.");
AddInput("X",
"(Tensor)"
"The input of fused_token_prune op, whose shape should be [bsz, "
"max_seq_len, c] and dtype should be float32/float64.");
AddInput(
"Mask",
"(Tensor)"
"The input of fused_token_prune op, whose shape should be [bsz, "
"num_head, "
"max_seq_len, max_seq_len] and dtype should be float32/float64."
"Mask is corresponding to Attn's elemnts one by one. Elements of Attn "
"will be set to zero if their corresponding mask is smaller than 0."
"This process happens before sorting X by attn.");
AddInput("NewMask",
"(Tensor)"
"The input of fused_token_prune op, whose shape should be [bsz, "
"num_head, slimmed_seq_len, slimmed_seq_len]."
"NewMask is just used to get slimmed_seq_len, so the value of "
"this input is not important in this op.");
AddOutput("SlimmedX",
"(Tensor)"
"The output of fused_token_prune op, whose shape should be [bsz, "
"slimmed_seq_len, C]."
"The tokens of X will be sorted by Attn firstly and then the "
"last (max_seq_len - slimmed_seq_len)"
"tokens will be deleted. SlimmedX is the remainning part of X. "
"");
AddOutput(
"CLSInds",
"(Tensor)"
"The output of fused_token_prune op, whose shape should be [bsz, "
"slimmed_seq_len] and dtype is int64. CLSInds contains token indices "
" of each batch after sorting and pruning. ");
AddAttr<bool>("keep_first_token",
"If keep_first_token is True, the element located in "
"CLSInds[:, 1] must be 0.")
.SetDefault(true);
AddAttr<bool>("keep_order",
"If keep_order is True, the relative order of SlimmedX and "
"CLSInds remains unchanged")
.SetDefault(false);
AddComment(R"DOC(
fused_token_prune op is used to fuse multiple ops to perform token pruning.
In this op:
1. Elements of Attn will be set to zero if their corresponding mask is smaller than 0.
2. The second dimension of X will be sorted by Attn.
3. The last (max_seq_len - slimmed_seq_len) rows of X will be pruned.
4. The remaining part of the sorted X will be output.
)DOC");
}
};
class FusedTokenPruneOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Attn"), "Input", "Attn", "FusedTokenPrune");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FusedTokenPrune");
OP_INOUT_CHECK(ctx->HasInput("Mask"), "Input", "Mask", "FusedTokenPrune");
OP_INOUT_CHECK(
ctx->HasInput("NewMask"), "Input", "NewMask", "FusedTokenPrune");
OP_INOUT_CHECK(
ctx->HasOutput("SlimmedX"), "Output", "SlimmedX", "FusedTokenPrune");
OP_INOUT_CHECK(
ctx->HasOutput("CLSInds"), "Output", "CLSInds", "FusedTokenPrune");
auto mask_dim = ctx->GetInputDim("Mask");
auto attn_dim = ctx->GetInputDim("Attn");
auto x_dim = ctx->GetInputDim("X");
auto new_mask_dim = ctx->GetInputDim("NewMask");
// check input dims number
PADDLE_ENFORCE_EQ(mask_dim.size(),
4,
platform::errors::InvalidArgument(
"The input mask must be 4-dimention"));
PADDLE_ENFORCE_EQ(attn_dim.size(),
4,
platform::errors::InvalidArgument(
"The input attn must be 4-dimention"));
PADDLE_ENFORCE_EQ(
x_dim.size(),
3,
platform::errors::InvalidArgument("The input x must be 4-dimention"));
PADDLE_ENFORCE_EQ(new_mask_dim.size(),
4,
platform::errors::InvalidArgument(
"The input attn must be 4-dimention"));
// check input dims relations
PADDLE_ENFORCE_EQ(mask_dim[0],
attn_dim[0],
platform::errors::InvalidArgument(
"The first dim of mask and attn should be the same"
"which is batch size"));
PADDLE_ENFORCE_EQ(mask_dim[1],
attn_dim[1],
platform::errors::InvalidArgument(
"The second dim of mask and attn should be the same"
"which is nb_head"));
PADDLE_ENFORCE_EQ(mask_dim[0],
x_dim[0],
platform::errors::InvalidArgument(
"The first dim of mask and x should be the same"
"which is batch size"));
PADDLE_ENFORCE_EQ(
mask_dim[2],
mask_dim[3],
platform::errors::InvalidArgument(
"The third dim and the fourth dim of mask should be the same"
"which is max seq len"));
PADDLE_ENFORCE_EQ(
attn_dim[2],
attn_dim[3],
platform::errors::InvalidArgument(
"The third dim and the fourth dim of mask should be the same"
"which is max seq len"));
PADDLE_ENFORCE_EQ(attn_dim[2],
mask_dim[2],
platform::errors::InvalidArgument(
"The third dim of mask and attn should be the same"
"which is max seq len"));
PADDLE_ENFORCE_EQ(attn_dim[2],
x_dim[1],
platform::errors::InvalidArgument(
"The third dim of mask and the second dim of attn"
"should be the same which is max seq len"));
auto bsz = mask_dim[0];
auto c = x_dim[2];
auto slim_seq_len = new_mask_dim[2];
ctx->SetOutputDim("SlimmedX", {bsz, slim_seq_len, c});
ctx->SetOutputDim("CLSInds", {bsz, slim_seq_len});
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
fused_token_prune,
ops::FusedTokenPruneOp,
ops::FusedTokenPruneOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <limits>
#ifdef __NVCC__
#include <cub/cub.cuh>
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/fused_token_prune_op.cu.h"
namespace paddle {
namespace operators {
using framework::Tensor;
template <typename T>
struct AttnMaskFunctor {
inline HOSTDEVICE T operator()(const T a, const T b) const {
return b >= 0 ? a : 0;
}
};
__global__ void FillIndex(int64_t* indices, int num_raws, int num_cols) {
int num_threads = num_raws * num_cols;
int tid = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
for (; tid < num_threads; tid += stride) {
int col = tid % num_cols;
indices[tid] = (int64_t)col;
}
}
template <typename T>
__global__ void TakeAlongAxis(const T* src,
T* dst,
int64_t* indices,
int num_raws,
int src_num_cols,
int dst_num_cols,
int num_elements) {
int num_threads = num_raws * dst_num_cols;
int tid = threadIdx.x + blockIdx.x * blockDim.x;
int stride = blockDim.x * gridDim.x;
for (; tid < num_threads; tid += stride) {
int raw = tid / dst_num_cols;
int col = tid % dst_num_cols;
for (int i = 0; i < num_elements; ++i) {
dst[tid * num_elements + i] =
*(src + (raw * src_num_cols + indices[tid]) * num_elements + i);
}
}
}
template <typename T>
__global__ void MaximumFirst(T* mat, int num_raws, int num_cols, T max_value) {
int num_threads = num_raws;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
for (; tid < num_threads; tid += stride) {
mat[tid * num_cols] = max_value;
}
}
template <typename T>
class FusedTokenPruneOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto& dev_ctx = context.cuda_device_context();
// Inouts
const Tensor* attn = context.Input<Tensor>("Attn");
const Tensor* x = context.Input<Tensor>("X");
const Tensor* mask = context.Input<Tensor>("Mask");
const Tensor* new_mask = context.Input<Tensor>("NewMask");
// Input dims
auto attn_dims = attn->dims();
auto x_dims = x->dims();
auto new_mask_dims = new_mask->dims();
auto bsz = attn_dims[0];
auto num_heads = attn_dims[1];
auto max_seq_len = attn_dims[2];
auto c = x_dims[2];
int slimmed_x_len = new_mask_dims[2];
// Attrs
const bool keep_first_token = context.Attr<bool>("keep_first_token");
const bool keep_order = context.Attr<bool>("keep_order");
// Outputs
Tensor* out_slimmed_x = context.Output<Tensor>("SlimmedX");
Tensor* slimmed_indices = context.Output<Tensor>("CLSInds");
auto* out_slimmed_x_data =
out_slimmed_x->mutable_data<T>(context.GetPlace());
auto* slimmed_indices_data =
slimmed_indices->mutable_data<int64_t>(context.GetPlace());
// Intermediate variable
Tensor attn_tmp;
auto* attn_tmp_data =
attn_tmp.mutable_data<T>(attn_dims, context.GetPlace());
Tensor attn_accu;
auto* attn_accu_data =
attn_accu.mutable_data<T>({bsz, max_seq_len}, context.GetPlace());
Tensor attn_accu_indices;
auto* attn_accu_indices_data = attn_accu_indices.mutable_data<int64_t>(
{bsz, max_seq_len}, context.GetPlace());
Tensor sort_attn_accu;
auto* sort_attn_accu_data =
sort_attn_accu.mutable_data<T>({bsz, max_seq_len}, context.GetPlace());
Tensor sort_attn_accu_indices;
auto* sort_attn_accu_indices_data =
sort_attn_accu_indices.mutable_data<int64_t>({bsz, max_seq_len},
context.GetPlace());
Tensor temp_storage;
// 1. Filter attn by mask
std::vector<const Tensor*> ins;
std::vector<Tensor*> outs;
ins.emplace_back(attn);
ins.emplace_back(mask);
outs.emplace_back(&attn_tmp);
LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
dev_ctx, ins, &outs, -1, AttnMaskFunctor<T>());
// 2. Reduce sum
const std::vector<int64_t> reduce_dims{1, 2};
phi::Reduce<T, kps::AddFunctor, kps::IdentityFunctor>(dev_ctx,
attn_tmp,
false,
reduce_dims,
false,
attn_accu.dtype(),
&attn_accu);
// 3. Prepare token indices
phi::backends::gpu::GpuLaunchConfig config =
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, bsz * max_seq_len);
FillIndex<<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(attn_accu_indices_data, bsz, max_seq_len);
// 4. Sort token indices by attn
if (keep_first_token) {
T max = std::numeric_limits<T>::max();
config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, bsz);
MaximumFirst<T>
<<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(attn_accu_data, bsz, max_seq_len, max);
}
size_t temp_storage_bytes = -1;
int num_items = bsz * max_seq_len;
int num_segments = bsz;
cub::CountingInputIterator<int64_t> counting_iter(0);
cub::TransformInputIterator<int64_t,
SegmentOffsetIter,
cub::CountingInputIterator<int64_t>>
segment_offsets_t(counting_iter, SegmentOffsetIter(max_seq_len));
// Determine temporary device storage requirements
PADDLE_ENFORCE_GPU_SUCCESS(
cub::DeviceSegmentedRadixSort::SortPairsDescending(
nullptr,
temp_storage_bytes,
attn_accu_data,
sort_attn_accu_data,
attn_accu_indices_data,
sort_attn_accu_indices_data,
num_items,
num_segments,
segment_offsets_t,
segment_offsets_t + 1,
0,
sizeof(T) * 8,
dev_ctx.stream()));
// Allocate temporary storage
int64_t temp_size = temp_storage_bytes;
auto* temp_storage_data =
temp_storage.mutable_data<uint8_t>({temp_size}, context.GetPlace());
// Run sorting operation
PADDLE_ENFORCE_GPU_SUCCESS(
cub::DeviceSegmentedRadixSort::SortPairsDescending(
temp_storage_data,
temp_storage_bytes,
attn_accu_data,
sort_attn_accu_data,
attn_accu_indices_data,
sort_attn_accu_indices_data,
num_items,
num_segments,
segment_offsets_t,
segment_offsets_t + 1,
0,
sizeof(T) * 8,
dev_ctx.stream()));
// 5. Slice
auto slimmed_indices_tmp =
phi::funcs::Slice<int64_t>(dev_ctx,
sort_attn_accu_indices,
{1} /*axes*/,
{0} /*starts*/,
{slimmed_x_len} /*ends*/);
if (keep_order) {
// 6. reorder
num_items = bsz * slimmed_x_len;
temp_storage_bytes = -1;
cub::TransformInputIterator<int64_t,
SegmentOffsetIter,
cub::CountingInputIterator<int64_t>>
segment_offsets_t2(counting_iter, SegmentOffsetIter(slimmed_x_len));
PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceSegmentedRadixSort::SortKeys(
nullptr,
temp_storage_bytes,
static_cast<int64_t*>(slimmed_indices_tmp.data()),
static_cast<int64_t*>(slimmed_indices->data()),
num_items,
num_segments,
segment_offsets_t2,
segment_offsets_t2 + 1,
0,
sizeof(int64_t) * 8,
dev_ctx.stream()));
temp_size = temp_storage_bytes;
temp_storage.Resize({temp_size});
temp_storage_data =
temp_storage.mutable_data<uint8_t>(context.GetPlace());
PADDLE_ENFORCE_GPU_SUCCESS(cub::DeviceSegmentedRadixSort::SortKeys(
temp_storage_data,
temp_storage_bytes,
static_cast<int64_t*>(slimmed_indices_tmp.data()),
static_cast<int64_t*>(slimmed_indices->data()),
num_items,
num_segments,
segment_offsets_t2,
segment_offsets_t2 + 1,
0,
sizeof(int64_t) * 8,
dev_ctx.stream()));
} else {
framework::TensorCopy(
slimmed_indices_tmp, context.GetPlace(), slimmed_indices);
}
// 7. Get slimmed X by indices
config =
phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, bsz * slimmed_x_len);
TakeAlongAxis<T><<<config.block_per_grid,
config.thread_per_block,
0,
dev_ctx.stream()>>>(x->data<T>(),
out_slimmed_x_data,
slimmed_indices->data<int64_t>(),
bsz,
max_seq_len,
slimmed_x_len,
c);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(fused_token_prune,
ops::FusedTokenPruneOpCUDAKernel<float>,
ops::FusedTokenPruneOpCUDAKernel<double>);
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/phi/kernels/funcs/slice.h"
#include "paddle/phi/kernels/gpu/reduce.h"
#include "paddle/phi/kernels/primitive/functor_primitives.h"
namespace paddle {
namespace operators {
HOSTDEVICE inline int CeilDivide(int n, int m) { return (n + m - 1) / m; }
inline int ComputeBlockSize(int col) {
if (col > 512)
return 1024;
else if (col > 256 && col <= 512)
return 512;
else if (col > 128 && col <= 256)
return 256;
else if (col > 64 && col <= 128)
return 128;
else
return 64;
}
// Iterator that maps a segment (row) index to its starting offset
// (row i starts at i * num_cols); used as the segment-offset iterator
// for cub::DeviceSegmentedRadixSort.
struct SegmentOffsetIter {
EIGEN_DEVICE_FUNC
explicit SegmentOffsetIter(int num_cols) : num_cols_(num_cols) {}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE int operator()(int idx) const {
return idx * num_cols_;
}
int num_cols_;
};
} // namespace operators
} // namespace paddle
......@@ -28,6 +28,13 @@ if(NOT WITH_DISTRIBUTE)
list(REMOVE_ITEM TEST_TRT_CONVERTER "test_trt_convert_c_allreduce")
endif()
if(WIN32)
list(REMOVE_ITEM TEST_INFERENCE_IR_PASSES
"test_trt_convert_fused_token_prune")
list(REMOVE_ITEM TEST_TRT_IR_PASSES "test_trt_convert_fused_token_prune")
list(REMOVE_ITEM TEST_TRT_CONVERTER "test_trt_convert_fused_token_prune")
endif()
# Only for cpu(mkl + openblas)
set(TEST_INFERENCE_CPU_UT "test_mul_lstm_fuse_pass" "test_mul_gru_fuse_pass")
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from trt_layer_auto_scan_test import TrtLayerAutoScanTest, SkipReasons
from program_config import TensorConfig, ProgramConfig
import numpy as np
import paddle.inference as paddle_infer
from functools import partial
from typing import Optional, List, Callable, Dict, Any, Set
import unittest
class TrtConvertFusedTokenPruneTest(TrtLayerAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
return True
def sample_program_configs(self):
self.trt_param.workspace_size = 1073741824
def generate_attn_or_mask(attrs: List[Dict[str, Any]]):
return np.ones([4, 12, 64, 64]).astype(np.float32)
def generate_x(attrs: List[Dict[str, Any]]):
return np.random.random([4, 64, 76]).astype(np.float32)
def generate_new_mask(attrs: List[Dict[str, Any]]):
return np.random.random([4, 12, 32, 32]).astype(np.float32)
for keep_first_token in [True, False]:
for keep_order in [True, False]:
dics = [{
"keep_first_token": keep_first_token,
"keep_order": keep_order
}]
ops_config = [{
"op_type": "fused_token_prune",
"op_inputs": {
"Attn": ["attn"],
"X": ["x"],
"Mask": ["mask"],
"NewMask": ["new_mask"]
},
"op_outputs": {
"SlimmedX": ["slimmed_x"],
"CLSInds": ["cls_inds"]
},
"op_attrs": dics[0]
}]
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"attn":
TensorConfig(
data_gen=partial(generate_attn_or_mask, dics)),
"x":
TensorConfig(data_gen=partial(generate_x, dics)),
"mask":
TensorConfig(
data_gen=partial(generate_attn_or_mask, dics)),
"new_mask":
TensorConfig(data_gen=partial(generate_new_mask, dics))
},
outputs=["slimmed_x", "cls_inds"])
yield program_config
def sample_predictor_configs(
self, program_config) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs):
self.dynamic_shape.min_input_shape = {
"attn": [4, 12, 64, 64],
"x": [4, 64, 76],
"mask": [4, 12, 64, 64],
"new_mask": [4, 12, 32, 32]
}
self.dynamic_shape.max_input_shape = {
"attn": [4, 12, 64, 64],
"x": [4, 64, 76],
"mask": [4, 12, 64, 64],
"new_mask": [4, 12, 32, 32]
}
self.dynamic_shape.opt_input_shape = {
"attn": [4, 12, 64, 64],
"x": [4, 64, 76],
"mask": [4, 12, 64, 64],
"new_mask": [4, 12, 32, 32]
}
def clear_dynamic_shape():
self.dynamic_shape.min_input_shape = {}
self.dynamic_shape.max_input_shape = {}
self.dynamic_shape.opt_input_shape = {}
def generate_trt_nodes_num(attrs, dynamic_shape):
return 1, 6
attrs = [
program_config.ops[i].attrs for i in range(len(program_config.ops))
]
generate_dynamic_shape(attrs)
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True), (1e-5, 1e-5, 1e-5, 1e-5)
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True), (1e-5, 1e-5, 1e-5, 1e-5)
def test(self):
self.run_test()
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from op_test import OpTest
from paddle.framework import core
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFusedTokenPruneOp(OpTest):
def setDtype(self):
self.dtype = np.float32
def setInouts(self):
attn = [[1, 2], [3, 4]]
attn = np.array(attn, dtype=self.dtype)
attn = np.expand_dims(attn, axis=0)
self.attn = np.expand_dims(
attn, axis=0) # [1,1,2,2] bsz = 1, nd_head=1, max_seq_len=2
mask = [[1, 1], [-1, -1]]
mask = np.array(mask, dtype=self.dtype)
mask = np.expand_dims(mask, axis=0)
self.mask = np.expand_dims(mask, axis=0) # same as attn
x = [[1, 2, 3], [4, 5, 6]]
x = np.array(x, dtype=self.dtype)
self.x = np.expand_dims(x,
axis=0) # [1, 2, 3] bsz = 1, max_seq_len=2, c=3
new_mask = [[1]]
new_mask = np.array(new_mask, dtype=self.dtype)
new_mask = np.expand_dims(new_mask, axis=0)
self.new_mask = np.expand_dims(new_mask, axis=0) #[1, 1, 1, 1]
out_slimmedx_py = [[[1, 2, 3]]]
self.out_slimmedx_py = np.array(out_slimmedx_py, dtype=self.dtype)
out_cls_inds_py = [[0]]
self.out_cls_inds_py = np.array(out_cls_inds_py, dtype='int64')
def setUp(self):
self.op_type = 'fused_token_prune'
self.setDtype()
self.setInouts()
self.inputs = {
'Attn': self.attn,
'Mask': self.mask,
'X': self.x,
'NewMask': self.new_mask
}
self.outputs = {
'SlimmedX': self.out_slimmedx_py,
'CLSInds': self.out_cls_inds_py
}
def test_check_output(self):
self.check_output_with_place(core.CUDAPlace(0))
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFusedTokenPruneOpFloat64(TestFusedTokenPruneOp):
def setDtype(self):
self.dtype = np.float64
@unittest.skipIf(not core.is_compiled_with_cuda(),
"core is not compiled with CUDA")
class TestFusedTokenPruneOp2(TestFusedTokenPruneOp):
def setInouts(self):
attn = [[[[1, 2, 3, 4], [4, 3, 2, 1], [5, 9, 5, 4], [9, 6, 5, 4]],
[[8, 5, 2, 0], [1, 0, 2, 3], [2, 2, 3, 2], [7, 4, 1, 8]]]]
self.attn = np.array(
attn,
dtype=self.dtype) # [1,2,4,4] bsz = 1, nd_head=2, max_seq_len=4
mask = [[[[-1, -1, -1, 1], [-1, -1, 1, 1], [-1, -1, 1, 1],
[-1, -1, 1, 1]],
[[-1, -1, 1, 1], [-1, -1, 1, 1], [-1, -1, 1, 1],
[-1, -1, 1, 1]]]]
self.mask = np.array(mask, dtype=self.dtype) # same as attn
x = [[[1.1, 1.1, 1.1], [2.2, 2.2, 2.2], [3.3, 3.3, 3.3],
[4.4, 4.4, 4.4]]]
self.x = np.array(
x, dtype=self.dtype) # [1, 4, 3] bsz = 1, max_seq_len=4, c=3
self.new_mask = np.random.rand(1, 2, 2,
2).astype(self.dtype) #[1, 2, 2, 2]
out_slimmedx_py = [[[1.1, 1.1, 1.1], [4.4, 4.4, 4.4]]] #[1, 2, 3]
self.out_slimmedx_py = np.array(out_slimmedx_py, dtype=self.dtype)
out_cls_inds_py = [[0, 3]]
self.out_cls_inds_py = np.array(out_cls_inds_py, dtype='int64')
if __name__ == "__main__":
unittest.main()
......@@ -233,6 +233,7 @@ STATIC_MODE_TESTING_LIST = [
'test_fused_elemwise_activation_op',
'test_fused_emb_seq_pool_op',
'test_fused_embedding_fc_lstm_op',
'test_fused_token_prune_op',
'test_fusion_gru_op',
'test_fusion_lstm_op',
'test_fusion_repeated_fc_relu_op',
......