Unverified commit f79be656, authored by F feng_shuai, committed by GitHub

padding the length of input for vit_attention (#45506)

* vit_384_opt

* just support trt8

* padding + unpadding

* fix:unit test

* refactor:padding

* fix: change the position of round_up

* refactor: delete workspace
Parent: b93b1e34
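
For context, this change lets the qkv_to_context (multi-head attention) TensorRT plugin accept inputs whose sequence length is not already a multiple of 8: the plugin rounds the length up, pads the QKV tensor, masks the padded key positions with a large negative qk bias so softmax ignores them, and strips the padding from the output. Below is a minimal host-side C++ sketch of that rounding and masking idea; it is not part of the commit, and the example length (197) is an arbitrary assumption.

```cpp
#include <cstdio>
#include <vector>

// Same rounding rule as the round_up() helper added in this commit:
// round seq_len up to the next multiple of `multiple`.
inline int round_up(int seq_len, int multiple = 8) {
  return ((seq_len + multiple - 1) / multiple) * multiple;
}

int main() {
  int real_seq_len = 197;                // arbitrary example length (assumption)
  int seq_len = round_up(real_seq_len);  // padded length: 200

  // Per-row qk bias analogous to what reset_qk_bias builds on the GPU:
  // 0 for real tokens, -1e20 for padded tokens so softmax gives them ~0 weight.
  std::vector<float> qk_bias(seq_len);
  for (int i = 0; i < seq_len; ++i) {
    qk_bias[i] = (i >= real_seq_len) ? -1e20f : 0.0f;
  }

  std::printf("real_seq_len=%d padded seq_len=%d masked positions=%d\n",
              real_seq_len, seq_len, seq_len - real_seq_len);
  return 0;
}
```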
@@ -536,8 +536,6 @@ class MultiheadMatMulOpConverter : public OpConverter {
                 "but it's (%d) now.",
                 input->getDimensions().nbDims));
         // transpose weight_data from m * n to n * m
-        auto* input_bias_qk =
-            engine_->GetITensor(op_desc.Input("BiasQK").front());
         TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT,
                                       static_cast<void*>(weight_data),
@@ -615,6 +613,17 @@ class MultiheadMatMulOpConverter : public OpConverter {
           std::vector<nvinfer1::ITensor*> plugin_inputs;
           plugin_inputs.push_back(fc_layer->getOutput(0));
+          auto inputs = op_desc.Inputs();
+          bool hasBiasQK =
+              (inputs.find("BiasQK") == inputs.end()) ? false : true;
+          nvinfer1::ITensor* input_bias_qk = nullptr;
+          if (hasBiasQK) {
+            input_bias_qk =
+                engine_->GetITensor(op_desc.Input("BiasQK").front());
+          } else {
+            // fake input will be updated in qkv_plugin
+            input_bias_qk = fc_layer->getOutput(0);
+          }
           plugin_inputs.push_back(input_bias_qk);
           bool with_fp16 =
               engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
...
@@ -50,6 +50,71 @@ __global__ void transpose(T *src,
               threadIdx.x] = src[blockIdx.x * size_per_head + threadIdx.x];
 }
 
+inline int round_up(int seq_len, int multiple = 32) {
+  PADDLE_ENFORCE_GT(
+      multiple,
+      0,
+      platform::errors::InvalidArgument(
+          "multiple should be a positive number,but it's (%d)", multiple));
+  return ((seq_len + multiple - 1) / multiple) * multiple;
+}
+
+template <typename T>
+__global__ void reset_qk_bias(T *input, int real_seq_len, int seq_len) {
+  if (threadIdx.x < seq_len) {
+    int id = threadIdx.x + blockIdx.x * seq_len;
+    input[id] = threadIdx.x >= real_seq_len ? (T)-1e20f : (T)0.0f;
+  }
+}
+
+template <typename T>
+__global__ void transpose_qkv_padding(
+    const T *src,  // (Batch, real_seq_len, 3, head_num * size_per_head)
+    T *dst,        // (3 * batch * head_num * seq_len * size_per_head)
+    const int batch_size,
+    const int seq_len,
+    const int head_num,
+    const int size_per_head,
+    const int real_seq_len) {
+  // const dim3 grid(seq_len, batch, 3);
+  // const dim3 block(head_size, head_num, 1);
+  int qkv_id = blockIdx.z;
+  int batch_id = blockIdx.y;
+  int seq_id = blockIdx.x;
+  int head_id = threadIdx.y;
+  const int dst_offset =
+      qkv_id * batch_size * head_num * seq_len * size_per_head +
+      batch_id * head_num * seq_len * size_per_head +
+      head_id * seq_len * size_per_head + seq_id * size_per_head;
+  const int src_offset =
+      batch_id * real_seq_len * 3 * head_num * size_per_head +
+      seq_id * 3 * head_num * size_per_head +
+      qkv_id * head_num * size_per_head + head_id * size_per_head;
+  if (seq_id < real_seq_len) {
+    dst[threadIdx.x + dst_offset] = src[threadIdx.x + src_offset];
+  } else if (seq_id < seq_len) {
+    dst[threadIdx.x + dst_offset] = 0;
+  }
+}
+
+template <typename T>
+__global__ void transpose_qkv_unpadding(const T *src,
+                                        T *dst,
+                                        const int batch_size,
+                                        const int seq_len,
+                                        const int head_num,
+                                        const int size_per_head,
+                                        const int real_seq_len) {
+  int batch_id = blockIdx.x / (head_num * real_seq_len);
+  int seq_id = blockIdx.x % real_seq_len;
+  int head_id = blockIdx.x % (head_num * real_seq_len) / real_seq_len;
+  dst[batch_id * head_num * real_seq_len * size_per_head +
+      seq_id * head_num * size_per_head + head_id * size_per_head +
+      threadIdx.x] = src[batch_id * head_num * seq_len * size_per_head +
+                         head_id * seq_len * size_per_head +
+                         seq_id * size_per_head + threadIdx.x];
+}
+
 template <typename T>
 __global__ void TransposeQkvKernel(const int H, const T *input, T *output) {
   // Input: BxSx3xNxH
@@ -209,6 +274,48 @@ nvinfer1::DimsExprs QkvToContextPluginDynamic::getOutputDimensions(
   return ret;
 }
 
+void QkvToContextPluginDynamic::configurePlugin(
+    const nvinfer1::DynamicPluginTensorDesc *in,
+    int nb_inputs,
+    const nvinfer1::DynamicPluginTensorDesc *out,
+    int nb_outputs) TRT_NOEXCEPT {
+  auto input_dims = in[0].desc.dims;
+  int batch = input_dims.d[0];
+  int real_seq_len = input_dims.d[1];
+  int seq_len = round_up(real_seq_len, 8);
+  if (batch != -1 && real_seq_len != -1) {
+    int device_id = 0;
+    cudaGetDevice(&device_id);
+    auto *device_ctx = static_cast<phi::GPUContext *>(
+        platform::DeviceContextPool::Instance().Get(
+            platform::CUDAPlace(device_id)));
+    const phi::GPUContext &dev_ctx = *device_ctx;
+    auto stream = dev_ctx.stream();
+    tensor_.Resize({batch, seq_len, seq_len, head_number_});
+    int blocks = batch * head_number_ * seq_len;
+    if (in[0].desc.type == nvinfer1::DataType::kHALF) {
+      mask_half_ = reinterpret_cast<half *>(
+          tensor_.mutable_data<int16_t>(platform::CUDAPlace(device_id)));
+      reset_qk_bias<<<blocks, 1024, 0, stream>>>(
+          mask_half_, real_seq_len, seq_len);
+    } else if (in[0].desc.type == nvinfer1::DataType::kFLOAT) {
+      fake_qk_bias_ = reinterpret_cast<float *>(
+          tensor_.mutable_data<int32_t>(platform::CUDAPlace(device_id)));
+      long size = sizeof(int32_t) * batch * seq_len * seq_len * head_number_;
+#ifdef PADDLE_WITH_HIP
+      PADDLE_ENFORCE_GPU_SUCCESS(
+          hipMemsetAsync(fake_qk_bias_, 0, size, dev_ctx.stream()));
+#else
+      PADDLE_ENFORCE_GPU_SUCCESS(
+          cudaMemsetAsync(fake_qk_bias_, 0, size, dev_ctx.stream()));
+#endif
+    } else {
+      PADDLE_THROW(platform::errors::Fatal(
+          "The QKV TRT Plugin's input type should be float or half."));
+    }
+  }
+}
+
 bool QkvToContextPluginDynamic::supportsFormatCombination(
     int pos,
     const nvinfer1::PluginTensorDesc *in_out,
@@ -277,15 +384,6 @@ __global__ void apply_scale(T *data, T scale, int n) {
 #endif
 }
 
-inline int round_up(int seq_len, int multiple = 32) {
-  PADDLE_ENFORCE_GT(
-      multiple,
-      0,
-      platform::errors::InvalidArgument(
-          "multiple should be a positive number,but it's (%d)", multiple));
-  return ((seq_len + multiple - 1) / multiple) * multiple;
-}
-
 template <typename T>
 __global__ void broadcast(const T *src,
                           T *dst,
@@ -342,6 +440,10 @@ int QkvToContextPluginDynamic::enqueue(
                                                  head_number_);
       qk_bias = temp_qk_bias;
     }
+    // fake qk_bias
+    if (ProductDim(input_desc[1].dims) == ProductDim(input_desc[0].dims)) {
+      qk_bias = fake_qk_bias_;
+    }
     const float *input1_data = static_cast<const float *>(qk_bias);
     // BxSx3xNxH => tptr: 3xBxNxSxH.
     TransposeQKV(
@@ -373,6 +475,16 @@ int QkvToContextPluginDynamic::enqueue(
   } else if (input_type == nvinfer1::DataType::kHALF) {
 #ifdef TRT_PLUGIN_FP16_AVALIABLE
     VLOG(1) << "TRT Plugin DataType selected. QkvToContext-->fp16";
+    int real_seq_len = seq_len;
+    int need_padding = false;
+    // fake qk_bias
+    if (ProductDim(input_desc[1].dims) == ProductDim(input_desc[0].dims)) {
+      seq_len = round_up(real_seq_len, 8);
+      scratch_size = batch * head_number_ * seq_len * seq_len * 1;
+      input_num = batch * seq_len * 3 * head_number_ * head_size_;
+      multihead_temp_tensor.Resize({scratch_size + input_num});
+      need_padding = (real_seq_len != seq_len) ? true : false;
+    }
     auto *multihead_temp_data =
         multihead_temp_tensor.mutable_data<int16_t>(  // NOLINT
             platform::CUDAPlace(device_id));
@@ -398,10 +510,27 @@ int QkvToContextPluginDynamic::enqueue(
                                                  head_number_);
       qk_bias = temp_qk_bias;
     }
+    // padding:    mask_half_ = [0,0,...-1e20f,-1e20f]
+    // no_padding: mask_half_ = [0,.....0,.........,0]
+    if (ProductDim(input_desc[1].dims) == ProductDim(input_desc[0].dims)) {
+      qk_bias = mask_half_;
+    }
     const half *input1_data = static_cast<const half *>(qk_bias);
     // BxSx3xNxH => tptr: 3xBxNxSxH.
+    if (need_padding) {
+      dim3 grid_p(seq_len, batch, 3);
+      dim3 block_p(head_size_, head_number_, 1);
+      transpose_qkv_padding<<<grid_p, block_p, 0, stream>>>(input0_data,
+                                                            tptr,
+                                                            batch,
+                                                            seq_len,
+                                                            head_number_,
+                                                            head_size_,
+                                                            real_seq_len);
+    } else {
+      TransposeQKV(
+          batch, seq_len, head_size_, head_number_, input0_data, tptr, stream);
+    }
     auto *device_ctx = static_cast<phi::GPUContext *>(
         platform::DeviceContextPool::Instance().Get(
@@ -430,8 +559,15 @@ int QkvToContextPluginDynamic::enqueue(
     int grid = batch * head_number_ * seq_len;
     int block = head_size_;
     half *output = static_cast<half *>(outputs[0]);
+    if (need_padding) {
+      int grid_u = batch * head_number_ * real_seq_len;
+      int block_u = head_size_;
+      transpose_qkv_unpadding<half><<<grid_u, block_u, 0, stream>>>(
+          tptr, output, batch, seq_len, head_number_, head_size_, real_seq_len);
+    } else {
+      transpose<half><<<grid, block, 0, stream>>>(
+          tptr, output, batch, seq_len, head_number_, head_size_);
+    }
 #else
     PADDLE_THROW(platform::errors::Fatal(
         "The Ernie(Bert) TensorRT Plugin should be "
...
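
As a side note on the qk bias values used above: adding -1e20 to the attention scores of padded key positions drives their softmax probability to effectively zero, so the padded columns do not change the attention output for the real tokens. A small standalone C++ sketch of that effect (illustrative only, not part of the commit; the score values are made up):

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Numerically stable softmax over one row of attention scores.
std::vector<float> softmax(const std::vector<float> &x) {
  float m = *std::max_element(x.begin(), x.end());
  std::vector<float> y(x.size());
  float sum = 0.0f;
  for (size_t i = 0; i < x.size(); ++i) {
    y[i] = std::exp(x[i] - m);
    sum += y[i];
  }
  for (float &v : y) v /= sum;
  return y;
}

int main() {
  // Three real keys padded to four; the padded position gets the -1e20f bias.
  std::vector<float> scores = {0.5f, 1.0f, -0.3f, 0.0f};
  std::vector<float> bias = {0.0f, 0.0f, 0.0f, -1e20f};
  for (size_t i = 0; i < scores.size(); ++i) scores[i] += bias[i];
  for (float p : softmax(scores)) std::printf("%.6f ", p);
  std::printf("\n");  // the padded position prints as 0.000000
  return 0;
}
```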
@@ -97,7 +97,7 @@ class QkvToContextPluginDynamic : public DynamicPluginTensorRT {
   void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
                        int nb_inputs,
                        const nvinfer1::DynamicPluginTensorDesc* out,
-                       int nb_outputs) TRT_NOEXCEPT override {}
+                       int nb_outputs) TRT_NOEXCEPT override;
 
   size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
                          int nb_inputs,
@@ -124,6 +124,9 @@ class QkvToContextPluginDynamic : public DynamicPluginTensorRT {
   int head_number_;
   int head_size_;
   float scale_;
+  framework::Tensor tensor_;
+  half* mask_half_;
+  float* fake_qk_bias_;
 };
 
 class QkvToContextPluginDynamicCreator : public nvinfer1::IPluginCreator {
...
@@ -424,6 +424,7 @@ class TrtConvertMultiHeadMatmulTest(TrtLayerAutoScanTest):
         # for static_shape
         clear_dynamic_shape()
         self.trt_param.precision = paddle_infer.PrecisionType.Float32
+        self.trt_param.workspace_size = 2013265920
         yield self.create_inference_config(), (1, 4), (1e-5, 1e-5)
         self.trt_param.precision = paddle_infer.PrecisionType.Half
         yield self.create_inference_config(), (1, 4), (1e-5, 1e-5)
@@ -431,6 +432,7 @@ class TrtConvertMultiHeadMatmulTest(TrtLayerAutoScanTest):
         # for dynamic_shape
         generate_dynamic_shape(attrs)
         self.trt_param.precision = paddle_infer.PrecisionType.Float32
+        self.trt_param.workspace_size = 2013265920
         yield self.create_inference_config(), (1, 3), (1e-5, 1e-4)
         self.trt_param.precision = paddle_infer.PrecisionType.Half
         yield self.create_inference_config(), (1, 3), (1e-5, 1e-5)
@@ -855,8 +857,8 @@ class TrtConvertVitToMultiHeadMatmulTest(TrtLayerAutoScanTest):
     def sample_program_configs(self):

-        def generate_input1():
-            return np.zeros((1, 256, 768), dtype=np.float32)
+        def generate_input1(batch, length):
+            return np.zeros((batch, length, 768), dtype=np.float32)

         def generate_weight1():
             return np.random.rand(768, 2304).astype(np.float32)
@@ -864,6 +866,10 @@ class TrtConvertVitToMultiHeadMatmulTest(TrtLayerAutoScanTest):
         def generate_weight2():
             return np.random.rand(2304).astype(np.float32)

+        for batch in [2, 4]:
+            self.batch = batch
+            for length in [64, 384]:
+                self.length = length
                 ops_config = [{
                     "op_type": "matmul_v2",
                     "op_inputs": {
@@ -902,7 +908,7 @@ class TrtConvertVitToMultiHeadMatmulTest(TrtLayerAutoScanTest):
                         "XShape": ["reshape1_output_xshape"]
                     },
                     "op_attrs": {
-                        "shape": [-1, 256, 3, 12, 64]
+                        "shape": [-1, self.length, 3, 12, 64]
                     }
                 }, {
                     "op_type": "transpose2",
@@ -1048,7 +1054,7 @@ class TrtConvertVitToMultiHeadMatmulTest(TrtLayerAutoScanTest):
                         "XShape": ["reshape2_output_xshape"]
                     },
                     "op_attrs": {
-                        "shape": [-1, 256, 768]
+                        "shape": [-1, self.length, 768]
                     }
                 }]
@@ -1063,7 +1069,9 @@ class TrtConvertVitToMultiHeadMatmulTest(TrtLayerAutoScanTest):
                     TensorConfig(data_gen=partial(generate_weight2))
                 },
                 inputs={
-                    "input_data1": TensorConfig(data_gen=partial(generate_input1))
+                    "input_data1":
+                    TensorConfig(
+                        data_gen=partial(generate_input1, batch, length))
                 },
                 outputs=["reshape2_output"])
@@ -1105,6 +1113,9 @@ class TrtConvertVitToMultiHeadMatmulTest(TrtLayerAutoScanTest):
         self.trt_param.precision = paddle_infer.PrecisionType.Half
         yield self.create_inference_config(), generate_trt_nodes_num(), (1e-3,
                                                                          1e-3)
+        self.trt_param.precision = paddle_infer.PrecisionType.Float32
+        yield self.create_inference_config(), generate_trt_nodes_num(), (1e-5,
+                                                                         1e-5)

     def add_skip_trt_case(self):
         pass
...