diff --git a/paddle/fluid/inference/tensorrt/plugin/convert_mask_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/convert_mask_plugin.cu index 3a02343033c0d90d19614211bed169ec998503c6..361b929681a83bf7153bd8070646475dc9160f15 100644 --- a/paddle/fluid/inference/tensorrt/plugin/convert_mask_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/convert_mask_plugin.cu @@ -31,26 +31,63 @@ namespace plugin { to the mask with the bertQKV fused_multihead_attention format */ constexpr size_t threadsPerCta128 = 2 * 2 * 32; +constexpr size_t threadsPerCta384 = 1 * 8 * 32; constexpr size_t xmmasM128 = 4; +constexpr size_t xmmasM384 = 24; +constexpr size_t packedMaskSize64 = xmmasM128 * threadsPerCta128; +constexpr size_t packedMaskSize96 = xmmasM128 * threadsPerCta128; constexpr size_t packedMaskSize128 = xmmasM128 * threadsPerCta128; +constexpr size_t packedMaskSize384 = xmmasM384 * threadsPerCta384; nvinfer1::DimsExprs ConvertMaskPluginDynamic::getOutputDimensions( int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs, nvinfer1::IExprBuilder& expr_builder) { assert(output_index == 0); + constexpr int BDIM = 0; + constexpr int SDIM = 1; + if (type_ == nvinfer1::DataType::kHALF) { + auto cms64 = expr_builder.constant(packedMaskSize64); + auto cms96 = expr_builder.constant(packedMaskSize96); auto cms128 = expr_builder.constant(packedMaskSize128); + auto cms384 = expr_builder.constant(packedMaskSize384); + auto c64 = expr_builder.constant(64); + auto c96 = expr_builder.constant(96); + auto c128 = expr_builder.constant(128); + auto c384 = expr_builder.constant(384); + + auto is64 = expr_builder.operation(nvinfer1::DimensionOperation::kEQUAL, + *inputs[0].d[SDIM], *c64); + auto is96 = expr_builder.operation(nvinfer1::DimensionOperation::kEQUAL, + *inputs[0].d[SDIM], *c96); + auto is128 = expr_builder.operation(nvinfer1::DimensionOperation::kEQUAL, + *inputs[0].d[SDIM], *c128); + auto is384 = expr_builder.operation(nvinfer1::DimensionOperation::kEQUAL, + *inputs[0].d[SDIM], *c384); + auto sel64 = expr_builder.operation(nvinfer1::DimensionOperation::kPROD, + *is64, *cms64); + auto sel96 = expr_builder.operation(nvinfer1::DimensionOperation::kPROD, + *is96, *cms96); + auto sel128 = expr_builder.operation(nvinfer1::DimensionOperation::kPROD, + *is128, *cms128); + auto sel384 = expr_builder.operation(nvinfer1::DimensionOperation::kPROD, + *is384, *cms384); + auto maskSize1 = expr_builder.operation(nvinfer1::DimensionOperation::kSUM, + *sel64, *sel96); + auto maskSize2 = expr_builder.operation(nvinfer1::DimensionOperation::kSUM, + *sel384, *sel128); + auto maskSize = expr_builder.operation(nvinfer1::DimensionOperation::kSUM, + *maskSize1, *maskSize2); auto fp16maskSize = - expr_builder.operation(nvinfer1::DimensionOperation::kPROD, *cms128, + expr_builder.operation(nvinfer1::DimensionOperation::kPROD, *maskSize, *expr_builder.constant(2)); nvinfer1::DimsExprs ret; ret.nbDims = 2; - ret.d[0] = inputs[0].d[0]; + ret.d[0] = inputs[0].d[BDIM]; ret.d[1] = fp16maskSize; - return ret; } nvinfer1::DimsExprs ret; @@ -187,7 +224,7 @@ int ConvertMaskPluginDynamic::enqueue( int batch = input_dims.d[0]; int seq_len = input_dims.d[1]; - assert(seq_len == 128); + // assert(seq_len == 64 || seq_len == 96 || seq_len == 128 || seq_len == 384); if (type_ == nvinfer1::DataType::kFLOAT) { IMaskPreprocess<<>>( @@ -204,11 +241,24 @@ int ConvertMaskPluginDynamic::enqueue( static_cast(inputs[0]), inputMaskSB, seq_len, batch); } size_t warps_m = 0, warps_n = 0, warps_k = 1; - if (seq_len == 128) { + if (seq_len == 64 || seq_len == 96 || seq_len == 128) { warps_m = 2; warps_n = 2; + } else if (seq_len == 384) { + warps_m = 1; + warps_n = 8; + } else { + assert(false); } - + /* + int* buf_h = (int*)malloc(batch * seq_len * sizeof(int)); + cudaMemcpy(buf_h, inputMaskSB, batch * seq_len * sizeof(int), + cudaMemcpyDeviceToHost); + for (int i = 0; i < batch*seq_len; ++ i) { + std::cerr << buf_h[i] << " "; + } + std::cerr << std::endl; + */ convertMask(seq_len, batch, warps_m, warps_n, warps_k, inputMaskSB, static_cast(outputs[0]), stream); cudaFree(inputMaskSB);