Commit d4dcc80d authored by: zlsh80826

MHA fp16

Parent 03acac2b
@@ -138,7 +138,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
*reshape_layer->getOutput(0),
nvinfer1::ReduceOperation::kMAX, 1, false);
*/
auto imask_tensor = engine_->GetITensor("imask_tensor");
// auto imask_tensor = engine_->GetITensor("imask_tensor");
auto imask_tensor = engine_->GetITensor("fused_mha_mask");
auto creator = GetPluginRegistry()->getPluginCreator(
"CustomQKVToContextPluginDynamic", "1");
......
@@ -173,8 +173,6 @@ class OpConverter {
"optim_input_shape should be same."));
}
}
std::cerr << "Declare input: " << input << std::endl;
if (input.find("stack_0.tmp_0") != std::string::npos) continue;
engine->DeclareInput(
input, FluidDataType2TRT(
var->Proto()->type().lod_tensor().tensor().data_type()),
......
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/convert_mask_plugin.h"
namespace paddle {
namespace inference {
@@ -26,6 +27,7 @@ class ScaleOpConverter : public OpConverter {
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
VLOG(3) << "convert a fluid scale op to tensorrt mul layer without bias";
std::cerr << "Scale converter" << std::endl;
framework::OpDesc op_desc(op, nullptr);
// Declare inputs
@@ -64,6 +66,12 @@ class ScaleOpConverter : public OpConverter {
platform::errors::Fatal(
"Paddle-TRT scale mode only support dimension >= 3"));
plugin::ConvertMaskPluginDynamic* plugin =
new plugin::ConvertMaskPluginDynamic();
auto convert_mask_layer = engine_->AddPluginV2(&input, 1, plugin);
convert_mask_layer->setName("convert_mask_layer");
engine_->SetITensor("fused_mha_mask", convert_mask_layer->getOutput(0));
nvinfer1::IShuffleLayer* expand_layer = nullptr;
nvinfer1::IShuffleLayer* squeeze_layer = nullptr;
......
@@ -183,8 +183,6 @@ class TRTConvertValidation {
std::vector<void*> buffers(num_bindings);
for (const std::string& name : input_output_names) {
// std::cerr << "Binding name: " << name << std::endl;
if (name.find("stack_0.tmp_0") != std::string::npos) continue;
auto* var = scope_.FindVar(name);
auto* tensor = var->GetMutable<framework::LoDTensor>();
const int bind_index = engine_->engine()->getBindingIndex(name.c_str());
......
@@ -2,7 +2,7 @@ nv_library(tensorrt_plugin
SRCS trt_plugin.cc split_op_plugin.cu elementwise_op_plugin.cu
prelu_op_plugin.cu trt_plugin_factory.cc gelu_op_plugin.cu
pool_op_plugin.cu swish_op_plugin.cu layer_norm_op_plugin.cu
cast_int_plugin.cu stack_op_plugin.cu
cast_int_plugin.cu stack_op_plugin.cu convert_mask_plugin.cu
instance_norm_op_plugin.cu emb_eltwise_layernorm_plugin.cu
qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu slice_op_plugin.cu hard_swish_op_plugin.cu
DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor)
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cassert>
#include <cstring>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/convert_mask_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
// Dynamic Plugin below.
#if IS_TRT_VERSION_GE(6000)
/* This plugin currently converts the matmul output [B, S, S]
to the mask with the bertQKV fused_multihead_attention format */
constexpr size_t threadsPerCta128 = 2 * 2 * 32;
constexpr size_t xmmasM128 = 4;
constexpr size_t packedMaskSize128 = xmmasM128 * threadsPerCta128;
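// Worked sizes for S = 128: the fused kernel runs with warps_m = 2,
// warps_n = 2, warps_k = 1, so threadsPerCta128 = 2 * 2 * 32 = 128 threads
// per CTA and xmmasM128 = ceil(128 / (16 * 2)) = 4 row tiles, giving
// packedMaskSize128 = 4 * 128 = 512 uint32 words per batch. The factor of 2
// applied in getOutputDimensions() below reports the same buffer in half
// elements, since getOutputDataType() declares the output as kHALF
// (512 * sizeof(uint32_t) == 1024 * sizeof(half)).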
nvinfer1::DimsExprs ConvertMaskPluginDynamic::getOutputDimensions(
int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
nvinfer1::IExprBuilder& expr_builder) {
auto cms128 = expr_builder.constant(packedMaskSize128);
auto fp16maskSize = expr_builder.operation(
nvinfer1::DimensionOperation::kPROD, *cms128, *expr_builder.constant(2));
nvinfer1::DimsExprs ret;
ret.nbDims = 2;
ret.d[0] = inputs[0].d[0];
ret.d[1] = fp16maskSize;
return ret;
}
bool ConvertMaskPluginDynamic::supportsFormatCombination(
int pos, const nvinfer1::PluginTensorDesc* in_out, int nb_inputs,
int nb_outputs) {
const nvinfer1::PluginTensorDesc& desc = in_out[pos];
/* input: [B, S, S] */
/* output: [B, 2*maskSize] */
assert(nb_inputs == 1);
assert(nb_outputs == 1);
if (pos == 0) {
std::cerr << "desc.type: " << static_cast<int>(desc.type) << " "
<< desc.dims.nbDims << std::endl;
return ((desc.type == nvinfer1::DataType::kFLOAT ||
desc.type == nvinfer1::DataType::kHALF) &&
desc.dims.nbDims == 3);
}
std::cerr << "output.type: " << static_cast<int>(desc.type) << " "
<< desc.dims.nbDims << std::endl;
// return desc.type == nvinfer1::DataType::kHALF;
return true;
}
nvinfer1::DataType ConvertMaskPluginDynamic::getOutputDataType(
int index, const nvinfer1::DataType* input_types, int nb_inputs) const {
PADDLE_ENFORCE_EQ(index, 0,
platform::errors::InvalidArgument(
"The convert mask plugin only has one input, so the "
"index value should be 0, but get %d.",
index));
return nvinfer1::DataType::kHALF;
}
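// Copies the first seq_len values of each batch's [S, S] mask product, cast
// to int, into an S x B buffer (the SxB layout fillSBSMaskKernel expects).
// Launched as <<<batch, seq_len>>>: blockIdx.x indexes the batch and
// threadIdx.x the sequence position. This assumes the first row of the mask
// product reproduces the original 0/1 sequence mask.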
template <typename T>
__global__ void CastToIntAndReduce(const T* input, int* output, int seq_len,
int batch) {
int bid = blockIdx.x;
int sid = threadIdx.x;
output[sid * batch + bid] =
static_cast<int>(input[bid * seq_len * seq_len + sid]);
}
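// Packs one batch's sequence mask into the bit layout consumed by the fused
// multihead attention kernel: each block handles one (row-tile mi, batch bi)
// pair, stages the S mask values in shared memory, and every thread emits a
// 32-bit word with 8 mask bits per ni step. The repeated offset lookups
// mirror the per-thread fragment layout expected by that kernel.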
__global__ void fillSBSMaskKernel(const uint32_t warps_m,
const uint32_t warps_n, const uint32_t S,
const int* inputMaskSB,
uint32_t* inputMaskX) {
extern __shared__ int shm_mask[]; // S mask elements of this batch
const size_t xmmas_n = (S + 16 * warps_n - 1) / (16 * warps_n);
const uint32_t threads_per_cta = blockDim.x;
const uint32_t xmmas_m = gridDim.x;
const uint32_t B = gridDim.y;
const uint32_t mi = blockIdx.x;
const uint32_t bi = blockIdx.y;
const uint32_t tidx = threadIdx.x;
const size_t warp = tidx / 32;
const size_t warp_m = warp % warps_m;
const size_t warp_n = warp / warps_m;
const size_t lane = tidx % 32;
const size_t col = warp_n * 16 + lane % 4 * 2;
// load the mask corresponding to one batch
for (uint32_t si = tidx; si < S; si += threads_per_cta) {
// not coalesced to conform to current input format: SxB
shm_mask[si] = inputMaskSB[si * B + bi];
}
__syncthreads();
uint32_t mask = 0u;
for (size_t ni = 0; ni < xmmas_n; ++ni) {
const int offset = ni * 16 * warps_n + col;
mask |= (shm_mask[offset + 0] == 1.f ? 1u : 0u) << (8 * ni + 0);
mask |= (shm_mask[offset + 1] == 1.f ? 1u : 0u) << (8 * ni + 1);
mask |= (shm_mask[offset + 0] == 1.f ? 1u : 0u) << (8 * ni + 2);
mask |= (shm_mask[offset + 1] == 1.f ? 1u : 0u) << (8 * ni + 3);
mask |= (shm_mask[offset + 8] == 1.f ? 1u : 0u) << (8 * ni + 4);
mask |= (shm_mask[offset + 9] == 1.f ? 1u : 0u) << (8 * ni + 5);
mask |= (shm_mask[offset + 8] == 1.f ? 1u : 0u) << (8 * ni + 6);
mask |= (shm_mask[offset + 9] == 1.f ? 1u : 0u) << (8 * ni + 7);
}
inputMaskX[(bi * xmmas_m + mi) * threads_per_cta + tidx] = mask;
}
void convertMask(const uint32_t S, const uint32_t B, const uint32_t warps_m,
const uint32_t warps_n, const uint32_t warps_k,
const int* inputMaskSB, uint32_t* inputMaskX,
cudaStream_t stream) {
const size_t xmmas_m = (S + 16 * warps_m - 1) / (16 * warps_m);
const size_t threads_per_cta = warps_m * warps_n * warps_k * 32;
dim3 grid(xmmas_m, B);
fillSBSMaskKernel<<<grid, threads_per_cta, S * sizeof(int), stream>>>(
warps_m, warps_n, S, inputMaskSB, inputMaskX);
}
int ConvertMaskPluginDynamic::enqueue(
const nvinfer1::PluginTensorDesc* input_desc,
const nvinfer1::PluginTensorDesc* output_desc, const void* const* inputs,
void* const* outputs, void* workspace, cudaStream_t stream) {
auto input_dims = input_desc[0].dims;
auto output_dims = output_desc[0].dims;
size_t num_elements = ProductDim(input_dims);
size_t out_num_elements = ProductDim(output_dims);
int batch = input_dims.d[0];
int seq_len = input_dims.d[1];
assert(num_elements == out_num_elements * seq_len);
assert(seq_len <= 1024);
assert(output_desc[0].type == nvinfer1::DataType::kHALF);
// temp use, should remove
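// (A likely replacement: request batch * seq_len * sizeof(int) bytes from
// getWorkspaceSize() and reuse the `workspace` pointer passed to enqueue()
// instead of doing a cudaMalloc/cudaFree on every call.)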
int* inputMaskSB;
cudaMalloc(&inputMaskSB, batch * seq_len * sizeof(int));
if (input_desc[0].type == nvinfer1::DataType::kFLOAT) {
CastToIntAndReduce<float><<<batch, seq_len, 0, stream>>>(
static_cast<const float*>(inputs[0]), inputMaskSB, seq_len, batch);
} else {
CastToIntAndReduce<half><<<batch, seq_len, 0, stream>>>(
static_cast<const half*>(inputs[0]), inputMaskSB, seq_len, batch);
}
assert(seq_len == 128);
size_t warps_m = 0, warps_n = 0, warps_k = 1;
if (seq_len == 128) {
warps_m = 2;
warps_n = 2;
}
convertMask(seq_len, batch, warps_m, warps_n, warps_k, inputMaskSB,
static_cast<uint32_t*>(outputs[0]), stream);
cudaFree(inputMaskSB);
return cudaGetLastError() != cudaSuccess;
}
#endif
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdio.h>
#include <cassert>
#include <string>
#include <vector>
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
#if IS_TRT_VERSION_GE(6000)
class ConvertMaskPluginDynamic : public DynamicPluginTensorRT {
public:
ConvertMaskPluginDynamic() {}
ConvertMaskPluginDynamic(void const* serial_data, size_t serial_length) {}
~ConvertMaskPluginDynamic() {}
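// The plugin carries no state, so clone() returns a default-constructed
// instance and the serialized payload is empty.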
nvinfer1::IPluginV2DynamicExt* clone() const override {
return new ConvertMaskPluginDynamic();
}
const char* getPluginType() const override { return "convert_mask_plugin"; }
int getNbOutputs() const override { return 1; }
int initialize() override { return 0; }
size_t getSerializationSize() const override { return 0; }
void serialize(void* buffer) const override {}
nvinfer1::DimsExprs getOutputDimensions(
int output_index, const nvinfer1::DimsExprs* inputs, int nb_inputs,
nvinfer1::IExprBuilder& expr_builder) override;
bool supportsFormatCombination(int pos,
const nvinfer1::PluginTensorDesc* in_out,
int nb_inputs, int nb_outputs) override;
void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in,
int nb_inputs,
const nvinfer1::DynamicPluginTensorDesc* out,
int nb_outputs) override {}
size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
int nb_inputs,
const nvinfer1::PluginTensorDesc* outputs,
int nb_outputs) const override {
return 0;
}
int enqueue(const nvinfer1::PluginTensorDesc* input_desc,
const nvinfer1::PluginTensorDesc* output_desc,
const void* const* inputs, void* const* outputs, void* workspace,
cudaStream_t stream) override;
nvinfer1::DataType getOutputDataType(int index,
const nvinfer1::DataType* input_types,
int nb_inputs) const override;
void destroy() override { delete this; }
};
class ConvertMaskPluginV2Creator : public nvinfer1::IPluginCreator {
public:
ConvertMaskPluginV2Creator() {}
const char* getPluginName() const override { return "convert_mask_plugin"; }
const char* getPluginVersion() const override { return "1"; }
const nvinfer1::PluginFieldCollection* getFieldNames() override {
return &field_collection_;
}
nvinfer1::IPluginV2* createPlugin(
const char* name, const nvinfer1::PluginFieldCollection* fc) override {
return nullptr;
}
nvinfer1::IPluginV2* deserializePlugin(const char* name,
const void* serial_data,
size_t serial_length) override {
auto plugin = new ConvertMaskPluginDynamic(serial_data, serial_length);
return plugin;
}
void setPluginNamespace(const char* lib_namespace) override {
plugin_namespace_ = lib_namespace;
}
const char* getPluginNamespace() const override {
return plugin_namespace_.c_str();
}
private:
std::string plugin_namespace_;
std::string plugin_name_;
nvinfer1::PluginFieldCollection field_collection_{0, nullptr};
std::vector<nvinfer1::PluginField> plugin_attributes_;
};
REGISTER_TRT_PLUGIN_V2(ConvertMaskPluginV2Creator);
#endif
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
@@ -227,8 +227,6 @@ class TensorRTEngineOp : public framework::OperatorBase {
// Bind input tensor to TRT.
for (const auto &x : Inputs("Xs")) {
if (param_names_.count(x)) continue;
// std::cerr << "runTRT name: " << x << std::endl;
if (x.find("stack_0.tmp_0") != std::string::npos) continue;
// convert input and copy to TRT engine's buffer
auto &t =
inference::analysis::GetFromScope<framework::LoDTensor>(scope, x);
......