diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index df1858f6e8b602e2fd8eb987980cbb30fdc4860c..7eca5aa126d9048430ac1efb555c44cd676583c0 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -2200,6 +2200,7 @@ USE_TRT_CONVERTER(fused_token_prune) USE_TRT_CONVERTER(layernorm_shift_partition) USE_TRT_CONVERTER(generic_plugin_creater) USE_TRT_CONVERTER(custom_plugin_creater) +USE_TRT_CONVERTER(lookup_table) #if PADDLE_WITH_CUSPARSELT && IS_TRT_VERSION_GE(8000) USE_TRT_CONVERTER(sparse_fc) USE_TRT_CONVERTER(sparse_multihead_matmul) diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index 9bc0ad9114bfa7fbbdf1b9658d6541fc849aff86..ed6508929ca1ffafe2f3ab1177fa581ad1646fa8 100644 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -76,7 +76,8 @@ list( fill_constant_op.cc fused_token_prune_op.cc layernorm_shift_partition_op.cc - generic_and_custom_plugin_creater.cc) + generic_and_custom_plugin_creater.cc + fused_lookup_tables_op.cc) if(${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 7 AND NOT WIN32) list(APPEND CONVERT_FILES emb_eltwise_layernorm.cc diff --git a/paddle/fluid/inference/tensorrt/convert/fused_lookup_tables_op.cc b/paddle/fluid/inference/tensorrt/convert/fused_lookup_tables_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..d34d322ce173fab5d1fa3045bb8cb354fc6847bc --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/fused_lookup_tables_op.cc @@ -0,0 +1,123 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" +#include "paddle/fluid/inference/tensorrt/convert/utils.h" +#include "paddle/fluid/inference/tensorrt/plugin/lookup_table.h" + +namespace paddle { +namespace framework { +class Scope; + +namespace proto { +class OpDesc; +} // namespace proto +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace inference { +namespace tensorrt { + +class FusedLookupTablesOpConverter : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, + bool test_mode) override { + if (!engine_->with_dynamic_shape()) { + PADDLE_THROW( + platform::errors::Fatal("lookup_table_op must with dynamic shape")); + } + + framework::OpDesc op_desc(op, nullptr); + auto ids_name = op_desc.Input("Ids").front(); + auto w_name = op_desc.Input("W").front(); + auto output_name = op_desc.Output("Out").front(); + bool enable_int8 = op_desc.HasAttr("enable_int8"); + std::vector plugin_inputs; + + auto ids_dims = engine_->GetITensor(ids_name)->getDimensions(); + if (ids_dims.d[ids_dims.nbDims - 1] == 1) { + nvinfer1::Dims new_ids_dims; + new_ids_dims.nbDims = ids_dims.nbDims - 1; + for (int i = 0; i < ids_dims.nbDims - 1; i++) { + new_ids_dims.d[i] = 0; + } + auto* reshape_layer = TRT_ENGINE_ADD_LAYER( + engine_, Shuffle, *(engine_->GetITensor(ids_name))); + reshape_layer->setReshapeDimensions(new_ids_dims); + reshape_layer->setName( + ("lookup_table: Shuffle (Output: " + output_name + ")").c_str()); + plugin_inputs.push_back(reshape_layer->getOutput(0)); + } 
else { + plugin_inputs.push_back(engine_->GetITensor(ids_name)); + } + + TensorRTEngine::Weight weight; + auto* w_var = scope.FindVar(w_name); + auto* w_tensor = w_var->GetMutable(); + auto w_dims = w_tensor->dims(); + weight = engine_->GetTrtWeight(w_name, *w_tensor); + auto weight_size = phi::product(w_dims); + bool output_fp16; + if (engine_->precision() == AnalysisConfig::Precision::kFloat32) { + output_fp16 = false; + } else { + output_fp16 = true; + } + + int32_t weight_width = static_cast(w_dims[1]); + + std::vector fields; + fields.emplace_back("lookup_table_weight", + weight.get().values, + GetPluginFieldType(weight.get().type), + static_cast(weight_size)); + fields.emplace_back("lookup_table_weight_width", + &weight_width, + nvinfer1::PluginFieldType::kINT32, + 1); + fields.emplace_back( + "output_fp16", &output_fp16, nvinfer1::PluginFieldType::kINT32, 1); + nvinfer1::PluginFieldCollection* plugin_ptr = + static_cast( + malloc(sizeof(*plugin_ptr) + + fields.size() * sizeof(nvinfer1::PluginField))); + plugin_ptr->nbFields = static_cast(fields.size()); + plugin_ptr->fields = fields.data(); + auto creator = + GetPluginRegistry()->getPluginCreator("LookupTablePluginDynamic", "1"); + auto plugin_obj = + creator->createPlugin("LookupTablePluginDynamic", plugin_ptr); + + auto plugin_layer = engine_->network()->addPluginV2( + plugin_inputs.data(), plugin_inputs.size(), *plugin_obj); + + plugin_layer->setName( + ("lookup_table: (Output: " + output_name + ")").c_str()); + engine_->SetITensor(output_name, plugin_layer->getOutput(0)); + free(plugin_ptr); + if (enable_int8) { + float out_scale = + PADDLE_GET_CONST(float, op_desc.GetAttr("out_threshold")); + engine_->SetTensorDynamicRange(plugin_layer->getOutput(0), out_scale); + } + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(lookup_table, FusedLookupTablesOpConverter); diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc 
b/paddle/fluid/inference/tensorrt/op_teller.cc index 5756efefa5b2f42f0184dbf0e0a62546b00c8135..a6eeffbb877d537d76ef15a4469337be51273aa6 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -2083,6 +2083,14 @@ struct SimpleOpTypeSetTeller : public Teller { } } + if (op_type == "lookup_table") { + if (!with_dynamic_shape) { + VLOG(3) << "the lookup_table does not support " + "static shape yet"; + return false; + } + } + if (use_no_calib_int8) { return int8_teller_set.count(op_type); } else { @@ -2201,7 +2209,8 @@ struct SimpleOpTypeSetTeller : public Teller { "shape", "squeeze2", "unsqueeze2", - "layernorm_shift_partition"}; + "layernorm_shift_partition", + "lookup_table"}; std::unordered_set teller_set{ "mul", "matmul", @@ -2312,7 +2321,8 @@ struct SimpleOpTypeSetTeller : public Teller { "squeeze2", "unsqueeze2", "fused_token_prune", - "layernorm_shift_partition"}; + "layernorm_shift_partition", + "lookup_table"}; }; struct GenericPluginTeller : public Teller { diff --git a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt index b091ef42d8cc15ff6f0bae88d197e1b269df265b..1d5db4ee57f97163417c5c3a2237bd9978d4b55a 100644 --- a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt @@ -33,7 +33,8 @@ list( preln_residual_bias_plugin.cu fused_token_prune_op_plugin.cu layernorm_shift_partition_op.cu - generic_plugin.cu) + generic_plugin.cu + lookup_table.cu) if(${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 7 AND NOT WIN32) list(APPEND TRT_FILES many_emb_layernorm_varseqlen_plugin.cu many_emb_Layernorm_varseqlen_kernelMTron.cu diff --git a/paddle/fluid/inference/tensorrt/plugin/common/bertCommon.h b/paddle/fluid/inference/tensorrt/plugin/common/bertCommon.h index fc07cdededc7b7a7ddb4269f9acf5231134e04de..56d4edd67936b3d656e6a137c5e1af7b555c82f5 100644 --- 
a/paddle/fluid/inference/tensorrt/plugin/common/bertCommon.h +++ b/paddle/fluid/inference/tensorrt/plugin/common/bertCommon.h @@ -14,8 +14,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#ifndef PADDLE_FLUID_INFERENCE_TENSORRT_PLUGIN_COMMON_BERTCOMMON_H_ -#define PADDLE_FLUID_INFERENCE_TENSORRT_PLUGIN_COMMON_BERTCOMMON_H_ +#pragma once #include #include @@ -220,5 +219,3 @@ inline nvinfer1::DataType fieldTypeToDataType( } // namespace tensorrt } // namespace inference } // namespace paddle - -#endif // PADDLE_FLUID_INFERENCE_TENSORRT_PLUGIN_COMMON_BERTCOMMON_H_ diff --git a/paddle/fluid/inference/tensorrt/plugin/common/common.cuh b/paddle/fluid/inference/tensorrt/plugin/common/common.cuh index 6e155de44d095fe7c55f86479b6b53f0a6263541..10bf23fbc531bd963d8a6894af23c74a0586c00f 100644 --- a/paddle/fluid/inference/tensorrt/plugin/common/common.cuh +++ b/paddle/fluid/inference/tensorrt/plugin/common/common.cuh @@ -1,5 +1,6 @@ // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & +// AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -13,11 +14,10 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#ifndef COMMON_CUH -#define COMMON_CUH +#pragma once -#include "cublas_v2.h" #include +#include "cublas_v2.h" using kv_float = cub::KeyValuePair; using kv_half = cub::KeyValuePair; @@ -28,22 +28,22 @@ __device__ inline T rsqrt(const T& x); template <> __device__ inline float rsqrt(const float& x) { - return rsqrtf(x); + return rsqrtf(x); } __device__ inline kv_float operator+(const kv_float& a, const kv_float& b) { - return kv_float(a.key + b.key, a.value + b.value); + return kv_float(a.key + b.key, a.value + b.value); } // Half Operations __device__ inline half2 __hadd2_with_fallback(const half2 a, const half2 b) { #if __CUDA_ARCH__ >= 530 - return __hadd2(a, b); + return __hadd2(a, b); #else - float2 out {}; - out.x = __half2float(a.x) + __half2float(b.x); - out.y = __half2float(a.y) + __half2float(b.y); - return __float22half2_rn(out); + float2 out{}; + out.x = __half2float(a.x) + __half2float(b.x); + out.y = __half2float(a.y) + __half2float(b.y); + return __float22half2_rn(out); #endif } #if __CUDA_ARCH__ < 530 @@ -53,14 +53,14 @@ template __device__ inline T operator*(const T& a, const T& b); template <> __device__ inline half2 operator+(const half2& a, const half2& b) { - return __hadd2_with_fallback(a, b); + return __hadd2_with_fallback(a, b); } template <> __device__ inline half2 operator*(const half2& a, const half2& b) { - float2 out {}; - out.x = __half2float(a.x) * __half2float(b.x); - out.y = __half2float(a.y) * __half2float(b.y); - return __float22half2_rn(out); + float2 out{}; + out.x = __half2float(a.x) * __half2float(b.x); + out.y = __half2float(a.y) * __half2float(b.y); + return __float22half2_rn(out); } template __device__ inline T operator+(const T& a, const T& b); @@ -74,70 +74,73 @@ template __device__ inline T operator*(const T& a, const T& b); template <> __device__ inline half operator+(const half& a, const half& b) { - return __float2half(__half2float(a) + __half2float(b)); + return __float2half(__half2float(a) + __half2float(b)); } 
template <> __device__ inline half& operator+=(half& a, const half& b) { - a = __float2half(__half2float(a) + __half2float(b)); - return a; + a = __float2half(__half2float(a) + __half2float(b)); + return a; } template <> __device__ inline half operator-(const half& a, const half& b) { - return __float2half(__half2float(a) - __half2float(b)); + return __float2half(__half2float(a) - __half2float(b)); } template <> __device__ inline half operator*(const half& a, const half& b) { - return __float2half(__half2float(a) * __half2float(b)); + return __float2half(__half2float(a) * __half2float(b)); } template <> __device__ inline half operator/(const half& a, const half& b) { - return __float2half(__half2float(a) / __half2float(b)); + return __float2half(__half2float(a) / __half2float(b)); } #endif template <> __device__ inline half rsqrt(const half& x) { #if __CUDA_ARCH__ >= 530 - return hrsqrt(x); + return hrsqrt(x); #else - return __float2half(rsqrt(__half2float(x))); + return __float2half(rsqrt(__half2float(x))); #endif } __device__ inline kv_half operator+(const kv_half& a, const kv_half& b) { - const half2 a2 = __halves2half2(a.key, a.value); - const half2 b2 = __halves2half2(b.key, b.value); - const half2 res = __hadd2_with_fallback(a2, b2); - return kv_half(res.x, res.y); + const half2 a2 = __halves2half2(a.key, a.value); + const half2 b2 = __halves2half2(b.key, b.value); + const half2 res = __hadd2_with_fallback(a2, b2); + return kv_half(res.x, res.y); } __device__ inline kv_half2 operator+(const kv_half2& a, const kv_half2& b) { - return kv_half2(__hadd2_with_fallback(a.key, b.key), __hadd2_with_fallback(a.value, b.value)); + return kv_half2(__hadd2_with_fallback(a.key, b.key), + __hadd2_with_fallback(a.value, b.value)); } // Helper Functions template using kvp = cub::KeyValuePair; template -__device__ inline void layerNorm( - const kvp& threadData, const int ld, const int offset, const P* beta, const P* gamma, T* output) { - // Assuming threadData is already 
divided by ld - using BlockReduce = cub::BlockReduce, TPB>; - __shared__ typename BlockReduce::TempStorage temp_storage; - __shared__ R mu; // mean - __shared__ R rsigma; // 1 / std.dev. - const auto sumKV = BlockReduce(temp_storage).Reduce(threadData, cub::Sum()); - if (threadIdx.x == 0) { - mu = sumKV.key; - rsigma = rsqrt(sumKV.value - mu * mu); - } - __syncthreads(); - for (int i = threadIdx.x; i < ld; i += TPB) { - const int idx = offset + i; - const R val = output[idx]; - const R g(gamma[i]); - const R b(beta[i]); - output[idx] = g * (val - mu) * rsigma + b; - } +__device__ inline void layerNorm(const kvp& threadData, + const int ld, + const int offset, + const P* beta, + const P* gamma, + T* output) { + // Assuming threadData is already divided by ld + using BlockReduce = cub::BlockReduce, TPB>; + __shared__ typename BlockReduce::TempStorage temp_storage; + __shared__ R mu; // mean + __shared__ R rsigma; // 1 / std.dev. + const auto sumKV = BlockReduce(temp_storage).Reduce(threadData, cub::Sum()); + if (threadIdx.x == 0) { + mu = sumKV.key; + rsigma = rsqrt(sumKV.value - mu * mu); + } + __syncthreads(); + for (int i = threadIdx.x; i < ld; i += TPB) { + const int idx = offset + i; + const R val = output[idx]; + const R g(gamma[i]); + const R b(beta[i]); + output[idx] = g * (val - mu) * rsigma + b; + } } - -#endif // #ifndef COMMON_CUH diff --git a/paddle/fluid/inference/tensorrt/plugin/common/plugin.h b/paddle/fluid/inference/tensorrt/plugin/common/plugin.h index de8d7dc2deafb0821c822bd21659af63b581e53b..1491f98512873fcc26488852ff853a648f6670ce 100644 --- a/paddle/fluid/inference/tensorrt/plugin/common/plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/common/plugin.h @@ -14,8 +14,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/plugin/lookup_table.h"

#include <mutex>
#include <unordered_map>

namespace paddle {
namespace inference {
namespace tensorrt {
namespace plugin {

// constexpr gives these internal linkage, so they cannot collide with
// same-named globals defined by other plugin translation units.
constexpr char const* PLUGINVERSION{"1"};
constexpr char const* LOOKUPTABLEPLUGINNAME{"LookupTablePluginDynamic"};

namespace {

// Reference counts for device weight buffers.
//
// clone() hands the same device pointer to the new plugin object, so the
// buffer may only be released once the last plugin referencing it is
// destroyed. The registry is heap-allocated and intentionally never freed so
// that plugins destroyed during static destruction can still use it.
struct WeightRegistry {
  std::mutex mutex;
  std::unordered_map<void*, int> refs;
};

WeightRegistry& weightRegistry() {
  static auto* registry = new WeightRegistry();
  return *registry;
}

// Registers one more owner of `dev` (no-op for nullptr).
void retainWeight(void* dev) {
  if (dev == nullptr) {
    return;
  }
  auto& registry = weightRegistry();
  std::lock_guard<std::mutex> guard(registry.mutex);
  ++registry.refs[dev];
}

// Drops one owner of `dev` and frees the buffer when no owner is left.
void releaseWeight(void* dev) {
  if (dev == nullptr) {
    return;
  }
  auto& registry = weightRegistry();
  std::lock_guard<std::mutex> guard(registry.mutex);
  auto it = registry.refs.find(dev);
  if (it == registry.refs.end()) {
    return;
  }
  if (--it->second == 0) {
    registry.refs.erase(it);
    cudaFree(dev);
  }
}

// Size in bytes of `count` weight elements of type `type`.
size_t weightBytes(nvinfer1::DataType type, int32_t count) {
  return count > 0 ? static_cast<size_t>(count) * getElementSize(type) : 0;
}

}  // namespace

/*
 * Gathers one embedding row per thread block.
 *
 * hiddenSize : number of elements in one embedding row
 * inputIds   : one id per block, indexed by the flattened block position
 * wordEmb    : embedding table, row-major, vocabSize x hiddenSize
 * vocabSize  : number of rows in the table
 * output     : one row of hiddenSize elements per block
 */
template <typename T, unsigned TPB>
__global__ void lookup_table_kernel(int hiddenSize,
                                    int32_t const* inputIds,
                                    T const* wordEmb,
                                    int32_t const vocabSize,
                                    T* output) {
  // 1. look up the id handled by this block
  __shared__ int wordId;
  // 64-bit arithmetic: position * hiddenSize overflows int32 on big inputs.
  int64_t const tokenPos =
      blockIdx.x + static_cast<int64_t>(blockIdx.y) * gridDim.x;
  if (threadIdx.x == 0) {
    wordId = inputIds[tokenPos];
  }
  __syncthreads();

  // 2. copy the selected embedding row to the output
  int64_t const outOffset = tokenPos * hiddenSize;
  if (wordId >= 0 && wordId < vocabSize) {
    // offset into the table is wordId * hiddenSize
    int64_t const woffset = static_cast<int64_t>(wordId) * hiddenSize;
    for (int it = threadIdx.x; it < hiddenSize; it += TPB) {
      output[outOffset + it] = wordEmb[woffset + it];
    }
  } else {
    // Report once per block instead of once per thread.
    if (threadIdx.x == 0) {
      printf(
          "Error!!!!!!(LookupTablePlugin): ID cannot be lookup "
          "table: ID < 0 or ID > max ");
    }
    // Write zeros so the row is deterministic rather than uninitialised.
    for (int it = threadIdx.x; it < hiddenSize; it += TPB) {
      output[outOffset + it] = T{};
    }
  }
}

/*
 * Launches lookup_table_kernel for B * S ids on `stream`.
 * Returns 0 on success (including the empty case) and STATUS_FAILURE when
 * the launch is rejected.
 */
template <typename T>
int lookup_table(cudaStream_t stream,
                 int hiddenSize,
                 int B,
                 int S,
                 int32_t const* inputIds,
                 T const* wordEmb,
                 int32_t const vocabSize,
                 T* output) {
  int64_t const tokens = static_cast<int64_t>(B) * S;
  if (tokens <= 0 || hiddenSize <= 0) {
    // A zero-sized grid is a launch error; there is nothing to copy.
    return 0;
  }
  if (tokens > 2147483647LL) {
    return STATUS_FAILURE;
  }
  constexpr int tpb = 256;
  // Flatten to a 1-D grid: gridDim.y is limited to 65535 blocks.
  dim3 const grid(static_cast<unsigned int>(tokens), 1, 1);
  dim3 const block(tpb, 1, 1);
  lookup_table_kernel<T, tpb><<<grid, block, 0, stream>>>(
      hiddenSize, inputIds, wordEmb, vocabSize, output);
  return cudaPeekAtLastError() == cudaSuccess ? 0 : STATUS_FAILURE;
}

// Static class fields initialization
nvinfer1::PluginFieldCollection LookupTablePluginDynamicCreator::mFC{};
std::vector<nvinfer1::PluginField>
    LookupTablePluginDynamicCreator::mPluginAttributes;

// Wraps an existing device buffer. The new object becomes one more owner of
// `weight_dev`; the buffer is freed when the last owner is destroyed.
LookupTablePluginDynamic::LookupTablePluginDynamic(
    nvinfer1::DataType const type,
    void* weight_dev,
    int32_t weight_size,
    int32_t width)
    : mType(type),
      mWeightDev(weight_dev),
      mWeightSize(weight_size),
      mWeightWidth(width) {
  retainWeight(mWeightDev);
}

// Rebuilds a plugin from the byte stream produced by serialize().
// On a malformed stream or a CUDA failure the weight pointer stays null and
// enqueue() reports STATUS_FAILURE.
LookupTablePluginDynamic::LookupTablePluginDynamic(void const* data,
                                                   size_t length) {
  // Deserialize in the same order as serialization
  deserialize_value(&data, &length, &mType);
  deserialize_value(&data, &length, &mWeightSize);
  deserialize_value(&data, &length, &mWeightWidth);
  // Size by the weight element type; sizeof(mType) is the size of the enum,
  // which over-reads the buffer for FP16 weights.
  size_t const bytes = weightBytes(mType, mWeightSize);
  if (bytes == 0 || bytes > length) {
    return;
  }
  void* dev{nullptr};
  if (cudaMalloc(&dev, bytes) != cudaSuccess) {
    return;
  }
  if (cudaMemcpy(dev, data, bytes, cudaMemcpyHostToDevice) != cudaSuccess) {
    cudaFree(dev);
    return;
  }
  mWeightDev = dev;
  retainWeight(mWeightDev);
}

// IPluginV2DynamicExt Methods
nvinfer1::IPluginV2DynamicExt* LookupTablePluginDynamic::clone()
    const noexcept {
  // The clone shares the device buffer; the constructor registers it as an
  // additional owner.
  auto* p = new LookupTablePluginDynamic(
      mType, mWeightDev, mWeightSize, mWeightWidth);
  p->setPluginNamespace(mNamespace.c_str());
  return p;
}

// Output shape is the ids shape with the embedding width appended.
nvinfer1::DimsExprs LookupTablePluginDynamic::getOutputDimensions(
    int32_t outputIndex,
    nvinfer1::DimsExprs const* inputs,
    int32_t nbInputs,
    nvinfer1::IExprBuilder& exprBuilder) noexcept {
  nvinfer1::DimsExprs ret;
  ret.nbDims = inputs[0].nbDims + 1;
  for (int i = 0; i < inputs[0].nbDims; ++i) {
    ret.d[i] = inputs[0].d[i];
  }
  ret.d[inputs[0].nbDims] = exprBuilder.constant(mWeightWidth);
  return ret;
}

// Position 0 is the int32 ids input, position 1 the output whose type
// follows the weight type. Everything is linear format.
bool LookupTablePluginDynamic::supportsFormatCombination(
    int32_t pos,
    nvinfer1::PluginTensorDesc const* inOut,
    int32_t nbInputs,
    int32_t nbOutputs) noexcept {
  if (inOut == nullptr || pos < 0 || pos >= nbInputs + nbOutputs) {
    return false;
  }
  nvinfer1::PluginTensorDesc const& desc = inOut[pos];
  if (desc.format != nvinfer1::TensorFormat::kLINEAR) {
    return false;
  }
  if (pos == 0) {
    return desc.type == nvinfer1::DataType::kINT32;
  }
  if (pos == 1) {
    if (mType == nvinfer1::DataType::kFLOAT) {
      return desc.type == nvinfer1::DataType::kFLOAT;
    }
    return desc.type == nvinfer1::DataType::kHALF;
  }
  // The plugin has exactly one input and one output.
  return false;
}

void LookupTablePluginDynamic::configurePlugin(
    nvinfer1::DynamicPluginTensorDesc const* inputs,
    int32_t nbInputs,
    nvinfer1::DynamicPluginTensorDesc const* outputs,
    int32_t nbOutputs) noexcept {}

size_t LookupTablePluginDynamic::getWorkspaceSize(
    nvinfer1::PluginTensorDesc const* inputs,
    int32_t nbInputs,
    nvinfer1::PluginTensorDesc const* outputs,
    int32_t nbOutputs) const noexcept {
  return 0;
}

int32_t LookupTablePluginDynamic::enqueue(
    nvinfer1::PluginTensorDesc const* inputDesc,
    nvinfer1::PluginTensorDesc const* outputDesc,
    void const* const* inputs,
    void* const* outputs,
    void* workspace,
    cudaStream_t stream) noexcept {
  if (mWeightDev == nullptr || mWeightWidth <= 0) {
    return STATUS_FAILURE;
  }
  nvinfer1::Dims const& idsDims = inputDesc[0].dims;
  int32_t const batchSize = idsDims.nbDims > 0 ? idsDims.d[0] : 1;
  // Every dimension after the first contributes ids, not only d[1].
  int64_t idsPerBatch = 1;
  for (int i = 1; i < idsDims.nbDims; ++i) {
    idsPerBatch *= idsDims.d[i];
  }
  if (idsPerBatch > 2147483647LL) {
    return STATUS_FAILURE;
  }
  int32_t const S = static_cast<int32_t>(idsPerBatch);
  // Number of rows in the embedding table.
  int32_t const vocabSize = mWeightSize / mWeightWidth;
  int32_t status = STATUS_FAILURE;
  auto const inputIds = static_cast<int32_t const*>(inputs[0]);
  if (mType == nvinfer1::DataType::kFLOAT) {
    auto output = static_cast<float*>(outputs[0]);
    auto const Weight = static_cast<float const*>(mWeightDev);
    status = lookup_table<float>(stream,
                                 static_cast<int32_t>(mWeightWidth),
                                 batchSize,
                                 S,
                                 inputIds,
                                 Weight,
                                 vocabSize,
                                 output);
  } else if (mType == nvinfer1::DataType::kHALF) {
    auto output = static_cast<half*>(outputs[0]);
    auto const Weight = static_cast<half const*>(mWeightDev);
    status = lookup_table<half>(stream,
                                static_cast<int32_t>(mWeightWidth),
                                batchSize,
                                S,
                                inputIds,
                                Weight,
                                vocabSize,
                                output);
  }
  return status;
}

// IPluginV2Ext Methods
nvinfer1::DataType LookupTablePluginDynamic::getOutputDataType(
    int32_t index,
    nvinfer1::DataType const* inputTypes,
    int32_t nbInputs) const noexcept {
  // Single output; always return a value so that no path falls off the end.
  assert(index == 0);
  assert(mType == nvinfer1::DataType::kHALF ||
         mType == nvinfer1::DataType::kFLOAT);
  return mType;
}

// IPluginV2 Methods
char const* LookupTablePluginDynamic::getPluginType() const noexcept {
  return LOOKUPTABLEPLUGINNAME;
}

char const* LookupTablePluginDynamic::getPluginVersion() const noexcept {
  return PLUGINVERSION;
}

int32_t LookupTablePluginDynamic::getNbOutputs() const noexcept { return 1; }

int32_t LookupTablePluginDynamic::initialize() noexcept { return 0; }

// The weights are acquired in the constructors, not in initialize(), and
// may be shared with clones, so they are released in destroy() instead.
void LookupTablePluginDynamic::terminate() noexcept {}

size_t LookupTablePluginDynamic::getSerializationSize() const noexcept {
  size_t const wordSize = getElementSize(mType);
  return sizeof(mType)            //
         + sizeof(mWeightSize)    //
         + sizeof(mWeightWidth)   //
         + wordSize * mWeightSize;  //
}

// Layout: mType, mWeightSize, mWeightWidth, then the raw weight bytes.
void LookupTablePluginDynamic::serialize(void* buffer) const noexcept {
  serialize_value(&buffer, mType);
  serialize_value(&buffer, mWeightSize);
  serialize_value(&buffer, mWeightWidth);
  char* d = static_cast<char*>(buffer);
  size_t const wordSize = getElementSize(mType);
  serFromDev(&d, static_cast<char*>(mWeightDev), mWeightSize * wordSize);
}

void LookupTablePluginDynamic::destroy() noexcept {
  // This gets called when the network containing plugin is destroyed
  releaseWeight(mWeightDev);
  mWeightDev = nullptr;
  delete this;
}

void LookupTablePluginDynamic::setPluginNamespace(
    char const* libNamespace) noexcept {
  mNamespace = libNamespace;
}

char const* LookupTablePluginDynamic::getPluginNamespace() const noexcept {
  return mNamespace.c_str();
}

LookupTablePluginDynamicCreator::LookupTablePluginDynamicCreator() {}

char const* LookupTablePluginDynamicCreator::getPluginName() const noexcept {
  return LOOKUPTABLEPLUGINNAME;
}

char const* LookupTablePluginDynamicCreator::getPluginVersion() const noexcept {
  return PLUGINVERSION;
}

nvinfer1::PluginFieldCollection const*
LookupTablePluginDynamicCreator::getFieldNames() noexcept {
  return &mFC;
}

/*
 * Reads the plugin fields out of `fc`.
 *
 *   "lookup_table_weight"       -> *weight (values, count, type)
 *   "lookup_table_weight_width" -> mWeightWidth (kINT32)
 *   "output_fp16"               -> return value (kINT32, non-zero = true)
 *
 * Fields that are absent leave their destination untouched; the return
 * value defaults to false.
 */
bool initializeFields(nvinfer1::PluginFieldCollection const* fc,
                      nvinfer1::Weights* weight,
                      int32_t& mWeightWidth) {  // NOLINT
  bool output_fp16 = false;
  for (int32_t i = 0; i < fc->nbFields; i++) {
    nvinfer1::PluginField const& field = fc->fields[i];
    if (field.name == nullptr || field.data == nullptr) {
      continue;
    }
    std::string field_name(field.name);
    if (field_name.compare("lookup_table_weight") == 0) {
      weight->values = field.data;
      weight->count = field.length;
      weight->type = fieldTypeToDataType(field.type);
    }
    if (field_name.compare("lookup_table_weight_width") == 0) {
      assert(field.type == nvinfer1::PluginFieldType::kINT32);
      mWeightWidth = *static_cast<int32_t const*>(field.data);
    }
    if (field_name.compare("output_fp16") == 0) {
      assert(field.type == nvinfer1::PluginFieldType::kINT32);
      output_fp16 = *static_cast<int32_t const*>(field.data) != 0;
    }
  }
  return output_fp16;
}

// Builds a plugin from the field collection and uploads the (converted)
// weights to the device. Returns nullptr when required fields are missing
// or a CUDA call fails. The caller owns the returned plugin and must
// destroy() it to release the device buffer.
nvinfer1::IPluginV2* LookupTablePluginDynamicCreator::createPlugin(
    char const* name, const nvinfer1::PluginFieldCollection* fc) noexcept {
  if (fc == nullptr) {
    return nullptr;
  }
  nvinfer1::Weights weight{nvinfer1::DataType::kFLOAT, nullptr, 0};
  int32_t weightWidth = 0;
  bool const output_fp16 = initializeFields(fc, &weight, weightWidth);
  if (weight.values == nullptr || weight.count <= 0 || weightWidth <= 0 ||
      weight.count % weightWidth != 0) {
    return nullptr;
  }
  nvinfer1::DataType const type =
      output_fp16 ? nvinfer1::DataType::kHALF : nvinfer1::DataType::kFLOAT;
  WeightsWithOwnership mWeight;
  mWeight.convertAndCopy(weight, type);
  size_t const bytes = getWeightsSize(mWeight, type);
  void* cudaMem{nullptr};
  if (cudaMalloc(&cudaMem, bytes) != cudaSuccess) {
    return nullptr;
  }
  if (cudaMemcpy(cudaMem, mWeight.values, bytes, cudaMemcpyHostToDevice) !=
      cudaSuccess) {
    cudaFree(cudaMem);
    return nullptr;
  }
  LookupTablePluginDynamic* p =
      new LookupTablePluginDynamic(type, cudaMem, mWeight.count, weightWidth);
  return p;
}

nvinfer1::IPluginV2* LookupTablePluginDynamicCreator::deserializePlugin(
    char const* name, void const* serialData, size_t serialLength) noexcept {
  return new LookupTablePluginDynamic(serialData, serialLength);
}

void LookupTablePluginDynamicCreator::setPluginNamespace(
    char const* libNamespace) noexcept {
  mNamespace = libNamespace;
}

char const* LookupTablePluginDynamicCreator::getPluginNamespace()
    const noexcept {
  return mNamespace.c_str();
}

}  // namespace plugin
}  // namespace tensorrt
}  // namespace inference
}  // namespace paddle
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include "NvInferPlugin.h" +#include "NvInferRuntime.h" +#include "paddle/fluid/inference/tensorrt/plugin/common/bertCommon.h" +#include "paddle/fluid/inference/tensorrt/plugin/common/serialize.h" +#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace inference { +namespace tensorrt { +namespace plugin { + +class LookupTablePluginDynamic : public nvinfer1::IPluginV2DynamicExt { + public: + LookupTablePluginDynamic(nvinfer1::DataType const type, + void* weight_dev, + int32_t weight_size, + int32_t width); + + LookupTablePluginDynamic(void const* data, size_t length); + + // It doesn't make sense to make EmbLayerNormVarSeqlenPlugin without + // arguments, so we delete default constructor. 
+ LookupTablePluginDynamic() = delete; + + // IPluginV2DynamicExt Methods + bool supportsFormatCombination(int32_t pos, + nvinfer1::PluginTensorDesc const* inOut, + int32_t nbInputs, + int32_t nbOutputs) noexcept override; + size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, + int32_t nbInputs, + nvinfer1::PluginTensorDesc const* outputs, + int32_t nbOutputs) const noexcept override; + + // IPluginV2Ext Methods + nvinfer1::DataType getOutputDataType( + int32_t index, + nvinfer1::DataType const* inputTypes, + int32_t nbInputs) const noexcept override; + + // IPluginV2 Methods + nvinfer1::IPluginV2DynamicExt* clone() const noexcept override; + nvinfer1::DimsExprs getOutputDimensions( + int32_t outputIndex, + const nvinfer1::DimsExprs* inputs, + int32_t nbInputs, + nvinfer1::IExprBuilder& exprBuilder) noexcept override; + void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, + int32_t nbInputs, + nvinfer1::DynamicPluginTensorDesc const* out, + int32_t nbOutputs) noexcept override; + char const* getPluginType() const noexcept override; + int32_t getNbOutputs() const noexcept override; + size_t getSerializationSize() const noexcept override; + void serialize(void* buffer) const noexcept override; + void destroy() noexcept override; + char const* getPluginNamespace() const noexcept override; + void setPluginNamespace(char const* pluginNamespace) noexcept override; + int32_t enqueue(nvinfer1::PluginTensorDesc const* inputDesc, + nvinfer1::PluginTensorDesc const* outputDesc, + void const* const* inputs, + void* const* outputs, + void* workspace, + cudaStream_t stream) noexcept override; + + int32_t initialize() noexcept override; + void terminate() noexcept override; + char const* getPluginVersion() const noexcept override; + + protected: + std::string mNamespace; + nvinfer1::DataType mType; + void* mWeightDev{nullptr}; + int32_t mWeightSize; + int32_t mWeightWidth; +}; + +class LookupTablePluginDynamicCreator : public nvinfer1::IPluginCreator { 
+ public: + LookupTablePluginDynamicCreator(); + + char const* getPluginName() const noexcept override; + + const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override; + + void setPluginNamespace(char const* pluginNamespace) noexcept override; + + char const* getPluginNamespace() const noexcept override; + + nvinfer1::IPluginV2* createPlugin( + char const* name, + const nvinfer1::PluginFieldCollection* fc) noexcept override; + char const* getPluginVersion() const noexcept override; + nvinfer1::IPluginV2* deserializePlugin(char const* name, + void const* serialData, + size_t serialLength) noexcept override; + + protected: + static nvinfer1::PluginFieldCollection mFC; + static std::vector mPluginAttributes; + std::string mNamespace; +}; + +REGISTER_TRT_PLUGIN_V2(LookupTablePluginDynamicCreator); + +} // namespace plugin +} // namespace tensorrt +} // namespace inference +} // namespace paddle diff --git a/paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernelHFace.cu b/paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernelHFace.cu index b89ac08404b4ce10b7d541d3d2d5c6ea555d904d..fd7ab67d0a4d1d898b3743e6afd8a9d17be6587b 100644 --- a/paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernelHFace.cu +++ b/paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernelHFace.cu @@ -21,10 +21,7 @@ #include #include "NvInfer.h" -#include "common/bertCommon.h" -#include "common/common.cuh" -#include "common/plugin.h" -#include "common/serialize.h" +#include "paddle/fluid/inference/tensorrt/plugin/common/common.cuh" #include "paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.h" namespace paddle { diff --git a/paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernelMTron.cu b/paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernelMTron.cu index 198d7c57b67afb4eeff7f6fe0c192ee5b6b3aafa..cd69a1ba37a7349a74469b83a62ca0d649fedbf1 100644 --- 
a/paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernelMTron.cu +++ b/paddle/fluid/inference/tensorrt/plugin/many_emb_Layernorm_varseqlen_kernelMTron.cu @@ -21,10 +21,7 @@ #include #include "NvInfer.h" -#include "common/bertCommon.h" -#include "common/common.cuh" -#include "common/plugin.h" -#include "common/serialize.h" +#include "paddle/fluid/inference/tensorrt/plugin/common/common.cuh" #include "paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.h" namespace paddle { diff --git a/paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.cu index 8ad149bd959122bb5bd75c9a07afb0fa210b5b32..9601f97f7d6bc98f8a24170e0eb77ad4f40ba2e0 100644 --- a/paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.cu @@ -19,7 +19,6 @@ #include #include #include "NvInfer.h" -#include "common/serialize.h" namespace paddle { namespace inference { diff --git a/paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.h b/paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.h index f53652cf3455d9c4ab9b1bda17a3d8176544c190..2886f800a64e0bb404f3c2ea7f75c419be5a7550 100644 --- a/paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/many_emb_layernorm_varseqlen_plugin.h @@ -18,7 +18,10 @@ #include #include "NvInferPlugin.h" #include "NvInferRuntime.h" -#include "common/bertCommon.h" + +#include "paddle/fluid/inference/tensorrt/plugin/common/bertCommon.h" +#include "paddle/fluid/inference/tensorrt/plugin/common/plugin.h" +#include "paddle/fluid/inference/tensorrt/plugin/common/serialize.h" #include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h" #include "paddle/fluid/platform/enforce.h" diff --git 
a/python/paddle/fluid/tests/unittests/ir/inference/test_emb_eltwise_layernorm_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_emb_eltwise_layernorm_fuse_pass.py index 8001c76816e652a6ddb5eb6ac1e6a4af48a8f7bc..e8e5e576f1b172ccc7c2684e2db81fe06c8fbec6 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_emb_eltwise_layernorm_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_emb_eltwise_layernorm_fuse_pass.py @@ -228,7 +228,7 @@ class TestEmbeddingEltwiseLayerNormFusePass(PassAutoScanTest): max_batch_size=4, workspace_size=102400, min_subgraph_size=0, - precision_mode=paddle_infer.PrecisionType.Float32, + precision_mode=paddle_infer.PrecisionType.Half, use_static=False, use_calib_mode=False) yield config, ['fused_embedding_eltwise_layernorm'], (1e-5, 1e-5) @@ -238,7 +238,7 @@ class TestEmbeddingEltwiseLayerNormFusePass(PassAutoScanTest): max_batch_size=4, workspace_size=102400, min_subgraph_size=0, - precision_mode=paddle_infer.PrecisionType.Float32, + precision_mode=paddle_infer.PrecisionType.Half, use_static=False, use_calib_mode=False) if program_config.ops[0].type == 'lookup_table':