提交 84103819 编写于 作者: Z zlsh80826

add 64/96/384 support

上级 d80ae5bc
...@@ -31,26 +31,63 @@ namespace plugin { ...@@ -31,26 +31,63 @@ namespace plugin {
to the mask with the bertQKV fused_multihead_attention format */ to the mask with the bertQKV fused_multihead_attention format */
constexpr size_t threadsPerCta128 = 2 * 2 * 32; constexpr size_t threadsPerCta128 = 2 * 2 * 32;
constexpr size_t threadsPerCta384 = 1 * 8 * 32;
constexpr size_t xmmasM128 = 4; constexpr size_t xmmasM128 = 4;
constexpr size_t xmmasM384 = 24;
constexpr size_t packedMaskSize64 = xmmasM128 * threadsPerCta128;
constexpr size_t packedMaskSize96 = xmmasM128 * threadsPerCta128;
constexpr size_t packedMaskSize128 = xmmasM128 * threadsPerCta128; constexpr size_t packedMaskSize128 = xmmasM128 * threadsPerCta128;
constexpr size_t packedMaskSize384 = xmmasM384 * threadsPerCta384;
nvinfer1::DimsExprs ConvertMaskPluginDynamic::getOutputDimensions( nvinfer1::DimsExprs ConvertMaskPluginDynamic::getOutputDimensions(
int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs, int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
nvinfer1::IExprBuilder& expr_builder) { nvinfer1::IExprBuilder& expr_builder) {
assert(output_index == 0); assert(output_index == 0);
constexpr int BDIM = 0;
constexpr int SDIM = 1;
if (type_ == nvinfer1::DataType::kHALF) { if (type_ == nvinfer1::DataType::kHALF) {
auto cms64 = expr_builder.constant(packedMaskSize64);
auto cms96 = expr_builder.constant(packedMaskSize96);
auto cms128 = expr_builder.constant(packedMaskSize128); auto cms128 = expr_builder.constant(packedMaskSize128);
auto cms384 = expr_builder.constant(packedMaskSize384);
auto c64 = expr_builder.constant(64);
auto c96 = expr_builder.constant(96);
auto c128 = expr_builder.constant(128);
auto c384 = expr_builder.constant(384);
auto is64 = expr_builder.operation(nvinfer1::DimensionOperation::kEQUAL,
*inputs[0].d[SDIM], *c64);
auto is96 = expr_builder.operation(nvinfer1::DimensionOperation::kEQUAL,
*inputs[0].d[SDIM], *c96);
auto is128 = expr_builder.operation(nvinfer1::DimensionOperation::kEQUAL,
*inputs[0].d[SDIM], *c128);
auto is384 = expr_builder.operation(nvinfer1::DimensionOperation::kEQUAL,
*inputs[0].d[SDIM], *c384);
auto sel64 = expr_builder.operation(nvinfer1::DimensionOperation::kPROD,
*is64, *cms64);
auto sel96 = expr_builder.operation(nvinfer1::DimensionOperation::kPROD,
*is96, *cms96);
auto sel128 = expr_builder.operation(nvinfer1::DimensionOperation::kPROD,
*is128, *cms128);
auto sel384 = expr_builder.operation(nvinfer1::DimensionOperation::kPROD,
*is384, *cms384);
auto maskSize1 = expr_builder.operation(nvinfer1::DimensionOperation::kSUM,
*sel64, *sel96);
auto maskSize2 = expr_builder.operation(nvinfer1::DimensionOperation::kSUM,
*sel384, *sel128);
auto maskSize = expr_builder.operation(nvinfer1::DimensionOperation::kSUM,
*maskSize1, *maskSize2);
auto fp16maskSize = auto fp16maskSize =
expr_builder.operation(nvinfer1::DimensionOperation::kPROD, *cms128, expr_builder.operation(nvinfer1::DimensionOperation::kPROD, *maskSize,
*expr_builder.constant(2)); *expr_builder.constant(2));
nvinfer1::DimsExprs ret; nvinfer1::DimsExprs ret;
ret.nbDims = 2; ret.nbDims = 2;
ret.d[0] = inputs[0].d[0]; ret.d[0] = inputs[0].d[BDIM];
ret.d[1] = fp16maskSize; ret.d[1] = fp16maskSize;
return ret; return ret;
} }
nvinfer1::DimsExprs ret; nvinfer1::DimsExprs ret;
...@@ -187,7 +224,7 @@ int ConvertMaskPluginDynamic::enqueue( ...@@ -187,7 +224,7 @@ int ConvertMaskPluginDynamic::enqueue(
int batch = input_dims.d[0]; int batch = input_dims.d[0];
int seq_len = input_dims.d[1]; int seq_len = input_dims.d[1];
assert(seq_len == 128); // assert(seq_len == 64 || seq_len == 96 || seq_len == 128 || seq_len == 384);
if (type_ == nvinfer1::DataType::kFLOAT) { if (type_ == nvinfer1::DataType::kFLOAT) {
IMaskPreprocess<<<batch, seq_len, 0, stream>>>( IMaskPreprocess<<<batch, seq_len, 0, stream>>>(
...@@ -204,11 +241,24 @@ int ConvertMaskPluginDynamic::enqueue( ...@@ -204,11 +241,24 @@ int ConvertMaskPluginDynamic::enqueue(
static_cast<const half*>(inputs[0]), inputMaskSB, seq_len, batch); static_cast<const half*>(inputs[0]), inputMaskSB, seq_len, batch);
} }
size_t warps_m = 0, warps_n = 0, warps_k = 1; size_t warps_m = 0, warps_n = 0, warps_k = 1;
if (seq_len == 128) { if (seq_len == 64 || seq_len == 96 || seq_len == 128) {
warps_m = 2; warps_m = 2;
warps_n = 2; warps_n = 2;
} else if (seq_len == 384) {
warps_m = 1;
warps_n = 8;
} else {
assert(false);
} }
/*
int* buf_h = (int*)malloc(batch * seq_len * sizeof(int));
cudaMemcpy(buf_h, inputMaskSB, batch * seq_len * sizeof(int),
cudaMemcpyDeviceToHost);
for (int i = 0; i < batch*seq_len; ++ i) {
std::cerr << buf_h[i] << " ";
}
std::cerr << std::endl;
*/
convertMask(seq_len, batch, warps_m, warps_n, warps_k, inputMaskSB, convertMask(seq_len, batch, warps_m, warps_n, warps_k, inputMaskSB,
static_cast<uint32_t*>(outputs[0]), stream); static_cast<uint32_t*>(outputs[0]), stream);
cudaFree(inputMaskSB); cudaFree(inputMaskSB);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册