Unverified commit 2a9c590b authored by W Wangzheee and committed by GitHub

[Paddle Inference] add lookup_table op_convert, add lookup_table plugin (#46613)

* add lookup_table op_convert, add lookup_table plugin
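The converter only engages when TensorRT runs with dynamic shapes (enforced in both the converter and the op teller below). A minimal, hedged sketch of a Paddle Inference C++ setup that would exercise this path — the model directory, the feed name "ids", and the shape values are illustrative assumptions, not part of this commit:

    #include "paddle_inference_api.h"

    paddle_infer::Config config("./model_dir");  // hypothetical model path
    config.EnableUseGpu(100 /*MB*/, 0 /*device id*/);
    // kHalf makes the converter request FP16 plugin output (output_fp16 = 1);
    // kFloat32 keeps the plugin output in FP32.
    config.EnableTensorRtEngine(1 << 30 /*workspace*/, 4 /*max batch*/,
                                0 /*min subgraph size*/,
                                paddle_infer::PrecisionType::kHalf,
                                false /*use_static*/, false /*use_calib_mode*/);
    // Dynamic shape is mandatory for lookup_table; "ids" stands in for the
    // model's real Ids input (note the trailing 1 the converter squeezes away).
    config.SetTRTDynamicShapeInfo({{"ids", {1, 1, 1}}},
                                  {{"ids", {4, 128, 1}}},
                                  {{"ids", {2, 64, 1}}});
    auto predictor = paddle_infer::CreatePredictor(config);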
Parent 19746835
......@@ -2200,6 +2200,7 @@ USE_TRT_CONVERTER(fused_token_prune)
USE_TRT_CONVERTER(layernorm_shift_partition)
USE_TRT_CONVERTER(generic_plugin_creater)
USE_TRT_CONVERTER(custom_plugin_creater)
USE_TRT_CONVERTER(lookup_table)
#if PADDLE_WITH_CUSPARSELT && IS_TRT_VERSION_GE(8000)
USE_TRT_CONVERTER(sparse_fc)
USE_TRT_CONVERTER(sparse_multihead_matmul)
......
......@@ -76,7 +76,8 @@ list(
fill_constant_op.cc
fused_token_prune_op.cc
layernorm_shift_partition_op.cc
generic_and_custom_plugin_creater.cc
fused_lookup_tables_op.cc)
if(${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 7 AND NOT WIN32)
list(APPEND CONVERT_FILES emb_eltwise_layernorm.cc
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/convert/utils.h"
#include "paddle/fluid/inference/tensorrt/plugin/lookup_table.h"
namespace paddle {
namespace framework {
class Scope;
namespace proto {
class OpDesc;
} // namespace proto
} // namespace framework
} // namespace paddle
namespace paddle {
namespace inference {
namespace tensorrt {
class FusedLookupTablesOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope,
bool test_mode) override {
if (!engine_->with_dynamic_shape()) {
PADDLE_THROW(platform::errors::Fatal(
"the lookup_table op converter requires dynamic shape"));
}
framework::OpDesc op_desc(op, nullptr);
auto ids_name = op_desc.Input("Ids").front();
auto w_name = op_desc.Input("W").front();
auto output_name = op_desc.Output("Out").front();
bool enable_int8 = op_desc.HasAttr("enable_int8");
std::vector<nvinfer1::ITensor*> plugin_inputs;
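// Paddle's lookup_table keeps a trailing dimension of 1 on Ids; squeeze it
// away so the plugin receives a plain index tensor. In the Shuffle reshape
// below, a 0 placeholder tells TensorRT to copy that dimension from the input.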
auto ids_dims = engine_->GetITensor(ids_name)->getDimensions();
if (ids_dims.d[ids_dims.nbDims - 1] == 1) {
nvinfer1::Dims new_ids_dims;
new_ids_dims.nbDims = ids_dims.nbDims - 1;
for (int i = 0; i < ids_dims.nbDims - 1; i++) {
new_ids_dims.d[i] = 0;
}
auto* reshape_layer = TRT_ENGINE_ADD_LAYER(
engine_, Shuffle, *(engine_->GetITensor(ids_name)));
reshape_layer->setReshapeDimensions(new_ids_dims);
reshape_layer->setName(
("lookup_table: Shuffle (Output: " + output_name + ")").c_str());
plugin_inputs.push_back(reshape_layer->getOutput(0));
} else {
plugin_inputs.push_back(engine_->GetITensor(ids_name));
}
TensorRTEngine::Weight weight;
auto* w_var = scope.FindVar(w_name);
auto* w_tensor = w_var->GetMutable<framework::LoDTensor>();
auto w_dims = w_tensor->dims();
weight = engine_->GetTrtWeight(w_name, *w_tensor);
auto weight_size = phi::product(w_dims);
// int32_t (not bool) because it is consumed as a kINT32 plugin field below
int32_t output_fp16;
if (engine_->precision() == AnalysisConfig::Precision::kFloat32) {
output_fp16 = 0;
} else {
output_fp16 = 1;
}
int32_t weight_width = static_cast<int32_t>(w_dims[1]);
std::vector<nvinfer1::PluginField> fields;
fields.emplace_back("lookup_table_weight",
weight.get().values,
GetPluginFieldType(weight.get().type),
static_cast<int32_t>(weight_size));
fields.emplace_back("lookup_table_weight_width",
&weight_width,
nvinfer1::PluginFieldType::kINT32,
1);
fields.emplace_back(
"output_fp16", &output_fp16, nvinfer1::PluginFieldType::kINT32, 1);
nvinfer1::PluginFieldCollection* plugin_ptr =
static_cast<nvinfer1::PluginFieldCollection*>(
malloc(sizeof(*plugin_ptr) +
fields.size() * sizeof(nvinfer1::PluginField)));
plugin_ptr->nbFields = static_cast<int>(fields.size());
plugin_ptr->fields = fields.data();
auto creator =
GetPluginRegistry()->getPluginCreator("LookupTablePluginDynamic", "1");
auto plugin_obj =
creator->createPlugin("LookupTablePluginDynamic", plugin_ptr);
auto plugin_layer = engine_->network()->addPluginV2(
plugin_inputs.data(), plugin_inputs.size(), *plugin_obj);
plugin_layer->setName(
("lookup_table: (Output: " + output_name + ")").c_str());
engine_->SetITensor(output_name, plugin_layer->getOutput(0));
free(plugin_ptr);
if (enable_int8) {
float out_scale =
PADDLE_GET_CONST(float, op_desc.GetAttr("out_threshold"));
engine_->SetTensorDynamicRange(plugin_layer->getOutput(0), out_scale);
}
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(lookup_table, FusedLookupTablesOpConverter);
......@@ -2083,6 +2083,14 @@ struct SimpleOpTypeSetTeller : public Teller {
}
}
if (op_type == "lookup_table") {
if (!with_dynamic_shape) {
VLOG(3) << "the lookup_table does not support "
"static shape yet";
return false;
}
}
if (use_no_calib_int8) {
return int8_teller_set.count(op_type);
} else {
......@@ -2201,7 +2209,8 @@ struct SimpleOpTypeSetTeller : public Teller {
"shape",
"squeeze2",
"unsqueeze2",
"layernorm_shift_partition"};
"layernorm_shift_partition",
"lookup_table"};
std::unordered_set<std::string> teller_set{
"mul",
"matmul",
......@@ -2312,7 +2321,8 @@ struct SimpleOpTypeSetTeller : public Teller {
"squeeze2",
"unsqueeze2",
"fused_token_prune",
"layernorm_shift_partition"};
"layernorm_shift_partition",
"lookup_table"};
};
struct GenericPluginTeller : public Teller {
......
......@@ -33,7 +33,8 @@ list(
preln_residual_bias_plugin.cu
fused_token_prune_op_plugin.cu
layernorm_shift_partition_op.cu
generic_plugin.cu
lookup_table.cu)
if(${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 7 AND NOT WIN32)
list(APPEND TRT_FILES many_emb_layernorm_varseqlen_plugin.cu
many_emb_Layernorm_varseqlen_kernelMTron.cu
......
......@@ -14,8 +14,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cublas_v2.h>
#include <cuda_fp16.h>
......@@ -220,5 +219,3 @@ inline nvinfer1::DataType fieldTypeToDataType(
} // namespace tensorrt
} // namespace inference
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
// SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION &
// AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -13,11 +14,10 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cub/cub.cuh>
#include "cublas_v2.h"
using kv_float = cub::KeyValuePair<float, float>;
using kv_half = cub::KeyValuePair<half, half>;
......@@ -28,22 +28,22 @@ __device__ inline T rsqrt(const T& x);
template <>
__device__ inline float rsqrt(const float& x) {
return rsqrtf(x);
}
__device__ inline kv_float operator+(const kv_float& a, const kv_float& b) {
return kv_float(a.key + b.key, a.value + b.value);
}
// Half Operations
__device__ inline half2 __hadd2_with_fallback(const half2 a, const half2 b) {
#if __CUDA_ARCH__ >= 530
return __hadd2(a, b);
#else
float2 out{};
out.x = __half2float(a.x) + __half2float(b.x);
out.y = __half2float(a.y) + __half2float(b.y);
return __float22half2_rn(out);
#endif
}
#if __CUDA_ARCH__ < 530
......@@ -53,14 +53,14 @@ template <typename T>
__device__ inline T operator*(const T& a, const T& b);
template <>
__device__ inline half2 operator+(const half2& a, const half2& b) {
return __hadd2_with_fallback(a, b);
}
template <>
__device__ inline half2 operator*(const half2& a, const half2& b) {
float2 out{};
out.x = __half2float(a.x) * __half2float(b.x);
out.y = __half2float(a.y) * __half2float(b.y);
return __float22half2_rn(out);
}
template <typename T>
__device__ inline T operator+(const T& a, const T& b);
......@@ -74,70 +74,73 @@ template <typename T>
__device__ inline T operator*(const T& a, const T& b);
template <>
__device__ inline half operator+(const half& a, const half& b) {
return __float2half(__half2float(a) + __half2float(b));
}
template <>
__device__ inline half& operator+=(half& a, const half& b) {
a = __float2half(__half2float(a) + __half2float(b));
return a;
}
template <>
__device__ inline half operator-(const half& a, const half& b) {
return __float2half(__half2float(a) - __half2float(b));
}
template <>
__device__ inline half operator*(const half& a, const half& b) {
return __float2half(__half2float(a) * __half2float(b));
}
template <>
__device__ inline half operator/(const half& a, const half& b) {
return __float2half(__half2float(a) / __half2float(b));
}
#endif
template <>
__device__ inline half rsqrt(const half& x) {
#if __CUDA_ARCH__ >= 530
return hrsqrt(x);
#else
return __float2half(rsqrt(__half2float(x)));
#endif
}
__device__ inline kv_half operator+(const kv_half& a, const kv_half& b) {
const half2 a2 = __halves2half2(a.key, a.value);
const half2 b2 = __halves2half2(b.key, b.value);
const half2 res = __hadd2_with_fallback(a2, b2);
return kv_half(res.x, res.y);
}
__device__ inline kv_half2 operator+(const kv_half2& a, const kv_half2& b) {
return kv_half2(__hadd2_with_fallback(a.key, b.key),
                __hadd2_with_fallback(a.value, b.value));
}
// Helper Functions
template <typename T>
using kvp = cub::KeyValuePair<T, T>;
template <typename T, typename R, typename P, int TPB>
__device__ inline void layerNorm(const kvp<R>& threadData,
const int ld,
const int offset,
const P* beta,
const P* gamma,
T* output) {
// Assuming threadData is already divided by ld
using BlockReduce = cub::BlockReduce<kvp<R>, TPB>;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ R mu; // mean
__shared__ R rsigma; // 1 / std.dev.
const auto sumKV = BlockReduce(temp_storage).Reduce(threadData, cub::Sum());
if (threadIdx.x == 0) {
mu = sumKV.key;
rsigma = rsqrt(sumKV.value - mu * mu);
}
__syncthreads();
for (int i = threadIdx.x; i < ld; i += TPB) {
const int idx = offset + i;
const R val = output[idx];
const R g(gamma[i]);
const R b(beta[i]);
output[idx] = g * (val - mu) * rsigma + b;
}
}
......@@ -14,8 +14,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cuda_runtime.h>
#include <cstring>
#include <iostream>
......@@ -60,4 +60,3 @@ class BaseCreator : public IPluginCreator {
};
} // namespace nvinfer1
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/inference/tensorrt/plugin/lookup_table.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
char const* PLUGINVERSION{"1"};
char const* LOOKUPTABLEPLUGINNAME{"LookupTablePluginDynamic"};
template <typename T, unsigned TPB>
__global__ void lookup_table_kernel(int weight_height,
int32_t const* inputIds,
T const* wordEmb,
int32_t const wordSize,
T* output) {
// 1. look up the word id handled by this block
// blockIdx.x = position in the sequence
// blockIdx.y = batch
// gridDim.x = S
// gridDim.y = B
__shared__ int wordId;
int32_t const seqPos = blockIdx.x + blockIdx.y * gridDim.x;
if (threadIdx.x == 0) {
wordId = inputIds[seqPos];
}
__syncthreads();
// 2. copy this word's embedding row to the output
// offset into the embedding table is wordId * weight_height
int32_t const woffset = wordId * weight_height;
// the output offset is seqPos * weight_height, i.e. (b * S + s) * weight_height
int32_t const outOffset = seqPos * weight_height;
if (wordId >= 0 && wordId < wordSize) {
for (int it = threadIdx.x; it < weight_height; it += TPB) {
T const w(wordEmb[woffset + it]);
output[outOffset + it] = w;
}
} else {
printf(
"Error (LookupTablePlugin): id out of range, expected 0 <= id < "
"vocab size\n");
return;
}
}
template <typename T>
int lookup_table(cudaStream_t stream,
int weight_height,
int B,
int S,
int32_t const* inputIds,
T const* wordEmb,
int32_t const wordSize,
T* output) {
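// Assumed layout (matching the converter): inputIds is a [B, S] int32
// tensor, wordEmb holds wordSize rows of weight_height elements each, and
// output is written as [B, S, weight_height].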
constexpr int tpb = 256;
dim3 const grid(S, B, 1);
dim3 const block(tpb, 1, 1);
lookup_table_kernel<T, tpb><<<grid, block, 0, stream>>>(
weight_height, inputIds, wordEmb, wordSize, output);
return 0;
}
// Static class fields initialization
nvinfer1::PluginFieldCollection LookupTablePluginDynamicCreator::mFC{};
std::vector<nvinfer1::PluginField>
LookupTablePluginDynamicCreator::mPluginAttributes;
LookupTablePluginDynamic::LookupTablePluginDynamic(
nvinfer1::DataType const type,
void* weight_dev,
int32_t weight_size,
int32_t width)
: mType(type),
mWeightDev(weight_dev),
mWeightSize(weight_size),
mWeightWidth(width) {}
LookupTablePluginDynamic::LookupTablePluginDynamic(void const* data,
size_t length) {
// Deserialize in the same order as serialization
deserialize_value(&data, &length, &mType);
deserialize_value(&data, &length, &mWeightSize);
deserialize_value(&data, &length, &mWeightWidth);
char const* d = static_cast<char const*>(data);
// element size must come from mType (e.g. 2 bytes for kHALF); sizeof(mType)
// is the size of the DataType enum and would over-copy for half weights
size_t const wordSize = getElementSize(mType);
cudaMalloc(&mWeightDev, mWeightSize * wordSize);
cudaMemcpy(mWeightDev, d, mWeightSize * wordSize, cudaMemcpyHostToDevice);
}
// IPluginV2DynamicExt Methods
nvinfer1::IPluginV2DynamicExt* LookupTablePluginDynamic::clone()
const noexcept {
auto p = new LookupTablePluginDynamic(
mType, mWeightDev, mWeightSize, mWeightWidth);
p->setPluginNamespace(mNamespace.c_str());
return p;
}
nvinfer1::DimsExprs LookupTablePluginDynamic::getOutputDimensions(
int32_t outputIndex,
nvinfer1::DimsExprs const* inputs,
int32_t nbInputs,
nvinfer1::IExprBuilder& exprBuilder) noexcept {
nvinfer1::DimsExprs ret;
ret.nbDims = inputs[0].nbDims + 1;
for (int i = 0; i < inputs[0].nbDims; ++i) {
ret.d[i] = inputs[0].d[i];
}
ret.d[inputs[0].nbDims] = exprBuilder.constant(mWeightWidth);
return ret;
}
bool LookupTablePluginDynamic::supportsFormatCombination(
int32_t pos,
nvinfer1::PluginTensorDesc const* inOut,
int32_t nbInputs,
int32_t nbOutputs) noexcept {
nvinfer1::PluginTensorDesc const& desc = inOut[pos];
if (desc.format != nvinfer1::TensorFormat::kLINEAR) {
return false;
}
if (pos == 0) {
return desc.type == nvinfer1::DataType::kINT32;
}
if (pos == 1) {
if (mType == nvinfer1::DataType::kFLOAT) {
return desc.type == nvinfer1::DataType::kFLOAT;
} else {
return desc.type == nvinfer1::DataType::kHALF;
}
}
// single input + single output means pos <= 1, but keep a defined return
return false;
}
void LookupTablePluginDynamic::configurePlugin(
nvinfer1::DynamicPluginTensorDesc const* inputs,
int32_t nbInputs,
nvinfer1::DynamicPluginTensorDesc const* outputs,
int32_t nbOutputs) noexcept {}
size_t LookupTablePluginDynamic::getWorkspaceSize(
nvinfer1::PluginTensorDesc const* inputs,
int32_t nbInputs,
nvinfer1::PluginTensorDesc const* outputs,
int32_t nbOutputs) const noexcept {
return 0;
}
int32_t LookupTablePluginDynamic::enqueue(
nvinfer1::PluginTensorDesc const* inputDesc,
nvinfer1::PluginTensorDesc const* outputDesc,
void const* const* inputs,
void* const* outputs,
void* workspace,
cudaStream_t stream) noexcept {
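// After the converter's trailing-1 squeeze, Ids arrives as [B] or [B, S];
// a 1-D input is treated as sequence length 1.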
int32_t const batchSize = inputDesc->dims.d[0];
int32_t S;
if (inputDesc->dims.nbDims == 1) {
S = 1;
} else {
S = inputDesc->dims.d[1];
}
int32_t mWeightHeight = mWeightSize / mWeightWidth;
int32_t status = STATUS_FAILURE;
auto const inputIds = static_cast<int32_t const*>(inputs[0]);
if (mType == nvinfer1::DataType::kFLOAT) {
auto output = static_cast<float*>(outputs[0]);
auto const Weight = static_cast<const float*>(mWeightDev);
status = lookup_table<float>(stream,
static_cast<int32_t>(mWeightWidth),
batchSize,
S,
inputIds,
Weight,
mWeightHeight,
output);
} else if (mType == nvinfer1::DataType::kHALF) {
auto output = static_cast<half*>(outputs[0]);
auto const Weight = static_cast<const half*>(mWeightDev);
status = lookup_table<half>(stream,
static_cast<int32_t>(mWeightWidth),
batchSize,
S,
inputIds,
Weight,
mWeightHeight,
output);
}
return status;
}
// IPluginV2Ext Methods
nvinfer1::DataType LookupTablePluginDynamic::getOutputDataType(
int32_t index,
nvinfer1::DataType const* inputTypes,
int32_t nbInputs) const noexcept {
assert(index == 0);
assert(mType == nvinfer1::DataType::kHALF ||
mType == nvinfer1::DataType::kFLOAT);
return mType;
}
// IPluginV2 Methods
char const* LookupTablePluginDynamic::getPluginType() const noexcept {
return LOOKUPTABLEPLUGINNAME;
}
char const* LookupTablePluginDynamic::getPluginVersion() const noexcept {
return PLUGINVERSION;
}
int32_t LookupTablePluginDynamic::getNbOutputs() const noexcept { return 1; }
int32_t LookupTablePluginDynamic::initialize() noexcept { return 0; }
void LookupTablePluginDynamic::terminate() noexcept { cudaFree(mWeightDev); }
size_t LookupTablePluginDynamic::getSerializationSize() const noexcept {
size_t const wordSize = getElementSize(mType);
return sizeof(mType) //
+ sizeof(mWeightSize) //
+ sizeof(mWeightWidth) //
+ wordSize * mWeightSize; //
}
void LookupTablePluginDynamic::serialize(void* buffer) const noexcept {
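// Keep this write order in sync with the deserializing constructor above.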
serialize_value(&buffer, mType);
serialize_value(&buffer, mWeightSize);
serialize_value(&buffer, mWeightWidth);
char* d = static_cast<char*>(buffer);
size_t const wordSize = getElementSize(mType);
serFromDev(&d, static_cast<char*>(mWeightDev), mWeightSize * wordSize);
}
void LookupTablePluginDynamic::destroy() noexcept {
// This gets called when the network containing plugin is destroyed
delete this;
}
void LookupTablePluginDynamic::setPluginNamespace(
char const* libNamespace) noexcept {
mNamespace = libNamespace;
}
char const* LookupTablePluginDynamic::getPluginNamespace() const noexcept {
return mNamespace.c_str();
}
LookupTablePluginDynamicCreator::LookupTablePluginDynamicCreator() {}
char const* LookupTablePluginDynamicCreator::getPluginName() const noexcept {
return LOOKUPTABLEPLUGINNAME;
}
char const* LookupTablePluginDynamicCreator::getPluginVersion() const noexcept {
return PLUGINVERSION;
}
nvinfer1::PluginFieldCollection const*
LookupTablePluginDynamicCreator::getFieldNames() noexcept {
return &mFC;
}
bool initializeFields(nvinfer1::PluginFieldCollection const* fc,
nvinfer1::Weights* weight,
int32_t& mWeightWidth) { // NOLINT
bool output_fp16 = false;
for (int32_t i = 0; i < fc->nbFields; i++) {
std::string field_name(fc->fields[i].name);
if (field_name.compare("lookup_table_weight") == 0) {
weight->values = fc->fields[i].data;
weight->count = fc->fields[i].length;
weight->type = fieldTypeToDataType(fc->fields[i].type);
}
if (field_name.compare("lookup_table_weight_width") == 0) {
assert(fc->fields[i].type == nvinfer1::PluginFieldType::kINT32);
mWeightWidth = static_cast<int32_t const*>(fc->fields[i].data)[0];
}
if (field_name.compare("output_fp16") == 0) {
assert(fc->fields[i].type == nvinfer1::PluginFieldType::kINT32);
output_fp16 = static_cast<int32_t const*>(fc->fields[i].data)[0] != 0;
}
}
return output_fp16;
}
nvinfer1::IPluginV2* LookupTablePluginDynamicCreator::createPlugin(
char const* name, const nvinfer1::PluginFieldCollection* fc) noexcept {
nvinfer1::Weights weight;
int32_t mWeightWidth;
bool output_fp16 = initializeFields(fc, &weight, mWeightWidth);
nvinfer1::DataType type;
if (output_fp16) {
type = nvinfer1::DataType::kHALF;
} else {
type = nvinfer1::DataType::kFLOAT;
}
WeightsWithOwnership mWeight;
mWeight.convertAndCopy(weight, type);
void* cudaMem{nullptr};
cudaMalloc(&cudaMem, getWeightsSize(mWeight, type));
cudaMemcpy(cudaMem,
mWeight.values,
getWeightsSize(mWeight, type),
cudaMemcpyHostToDevice);
LookupTablePluginDynamic* p =
new LookupTablePluginDynamic(type, cudaMem, mWeight.count, mWeightWidth);
return p;
}
nvinfer1::IPluginV2* LookupTablePluginDynamicCreator::deserializePlugin(
char const* name, void const* serialData, size_t serialLength) noexcept {
return new LookupTablePluginDynamic(serialData, serialLength);
}
void LookupTablePluginDynamicCreator::setPluginNamespace(
char const* libNamespace) noexcept {
mNamespace = libNamespace;
}
char const* LookupTablePluginDynamicCreator::getPluginNamespace()
const noexcept {
return mNamespace.c_str();
}
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cuda.h>
#include "NvInferPlugin.h"
#include "NvInferRuntime.h"
#include "paddle/fluid/inference/tensorrt/plugin/common/bertCommon.h"
#include "paddle/fluid/inference/tensorrt/plugin/common/serialize.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {
class LookupTablePluginDynamic : public nvinfer1::IPluginV2DynamicExt {
public:
LookupTablePluginDynamic(nvinfer1::DataType const type,
void* weight_dev,
int32_t weight_size,
int32_t width);
LookupTablePluginDynamic(void const* data, size_t length);
// It doesn't make sense to construct a LookupTablePluginDynamic without
// arguments, so the default constructor is deleted.
LookupTablePluginDynamic() = delete;
// IPluginV2DynamicExt Methods
bool supportsFormatCombination(int32_t pos,
nvinfer1::PluginTensorDesc const* inOut,
int32_t nbInputs,
int32_t nbOutputs) noexcept override;
size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs,
int32_t nbInputs,
nvinfer1::PluginTensorDesc const* outputs,
int32_t nbOutputs) const noexcept override;
// IPluginV2Ext Methods
nvinfer1::DataType getOutputDataType(
int32_t index,
nvinfer1::DataType const* inputTypes,
int32_t nbInputs) const noexcept override;
// IPluginV2 Methods
nvinfer1::IPluginV2DynamicExt* clone() const noexcept override;
nvinfer1::DimsExprs getOutputDimensions(
int32_t outputIndex,
const nvinfer1::DimsExprs* inputs,
int32_t nbInputs,
nvinfer1::IExprBuilder& exprBuilder) noexcept override;
void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in,
int32_t nbInputs,
nvinfer1::DynamicPluginTensorDesc const* out,
int32_t nbOutputs) noexcept override;
char const* getPluginType() const noexcept override;
int32_t getNbOutputs() const noexcept override;
size_t getSerializationSize() const noexcept override;
void serialize(void* buffer) const noexcept override;
void destroy() noexcept override;
char const* getPluginNamespace() const noexcept override;
void setPluginNamespace(char const* pluginNamespace) noexcept override;
int32_t enqueue(nvinfer1::PluginTensorDesc const* inputDesc,
nvinfer1::PluginTensorDesc const* outputDesc,
void const* const* inputs,
void* const* outputs,
void* workspace,
cudaStream_t stream) noexcept override;
int32_t initialize() noexcept override;
void terminate() noexcept override;
char const* getPluginVersion() const noexcept override;
protected:
std::string mNamespace;
nvinfer1::DataType mType;
void* mWeightDev{nullptr};
int32_t mWeightSize;
int32_t mWeightWidth;
};
class LookupTablePluginDynamicCreator : public nvinfer1::IPluginCreator {
public:
LookupTablePluginDynamicCreator();
char const* getPluginName() const noexcept override;
const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override;
void setPluginNamespace(char const* pluginNamespace) noexcept override;
char const* getPluginNamespace() const noexcept override;
nvinfer1::IPluginV2* createPlugin(
char const* name,
const nvinfer1::PluginFieldCollection* fc) noexcept override;
char const* getPluginVersion() const noexcept override;
nvinfer1::IPluginV2* deserializePlugin(char const* name,
void const* serialData,
size_t serialLength) noexcept override;
protected:
static nvinfer1::PluginFieldCollection mFC;
static std::vector<nvinfer1::PluginField> mPluginAttributes;
std::string mNamespace;
};
REGISTER_TRT_PLUGIN_V2(LookupTablePluginDynamicCreator);
} // namespace plugin
} // namespace tensorrt
} // namespace inference
} // namespace paddle
......@@ -21,10 +21,7 @@
#include <vector>
#include "NvInfer.h"
#include "common/bertCommon.h"
#include "common/common.cuh"
#include "common/plugin.h"
#include "common/serialize.h"
#include "paddle/fluid/inference/tensorrt/plugin/common/common.cuh"
#include "paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.h"
namespace paddle {
......
......@@ -21,10 +21,7 @@
#include <vector>
#include "NvInfer.h"
#include "common/bertCommon.h"
#include "common/common.cuh"
#include "common/plugin.h"
#include "common/serialize.h"
#include "paddle/fluid/inference/tensorrt/plugin/common/common.cuh"
#include "paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.h"
namespace paddle {
......
......@@ -19,7 +19,6 @@
#include <cstring>
#include <vector>
#include "NvInfer.h"
#include "common/serialize.h"
namespace paddle {
namespace inference {
......
......@@ -18,7 +18,10 @@
#include <cuda.h>
#include "NvInferPlugin.h"
#include "NvInferRuntime.h"
#include "common/bertCommon.h"
#include "paddle/fluid/inference/tensorrt/plugin/common/bertCommon.h"
#include "paddle/fluid/inference/tensorrt/plugin/common/plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/common/serialize.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/platform/enforce.h"
......
......@@ -228,7 +228,7 @@ class TestEmbeddingEltwiseLayerNormFusePass(PassAutoScanTest):
max_batch_size=4,
workspace_size=102400,
min_subgraph_size=0,
precision_mode=paddle_infer.PrecisionType.Half,
use_static=False,
use_calib_mode=False)
yield config, ['fused_embedding_eltwise_layernorm'], (1e-5, 1e-5)
......@@ -238,7 +238,7 @@ class TestEmbeddingEltwiseLayerNormFusePass(PassAutoScanTest):
max_batch_size=4,
workspace_size=102400,
min_subgraph_size=0,
precision_mode=paddle_infer.PrecisionType.Half,
use_static=False,
use_calib_mode=False)
if program_config.ops[0].type == 'lookup_table':
......