From f4d9212de25a7a8c5b5d3d160ed6ce1c4f40bdd0 Mon Sep 17 00:00:00 2001
From: Wilber <jiweibo@baidu.com>
Date: Tue, 23 Mar 2021 15:11:02 +0800
Subject: [PATCH] trt plugin upgrade to pluginv2ext (#31670)
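
Migrate the SplitPlugin from the deprecated IPluginExt-based
PluginTensorRT base class to a new PluginTensorRTV2Ext base class built
on nvinfer1::IPluginV2Ext.

* engine: add TensorRTEngine::AddPluginV2Ext() and an
  owned_plugin_v2ext_ container so the engine also owns the new-style
  plugins.
* split plugin: drop the old REGISTER_TRT_PLUGIN deserialize callback
  and register a SplitPluginCreator through REGISTER_TRT_PLUGIN_V2
  instead; the plugin type string becomes "split_plugin_v2ext".
* trt_plugin: factor the base-member (de)serialization into
  Seria/Deseria/SeriaSize helpers shared by PluginTensorRT and
  PluginTensorRTV2Ext.
* tests: add test_split_plugin.cc covering the plugin and its creator.
* setup.py.in: additionally package libxpuapi.so and libxpurt.so from
  XPU_SDK_ROOT for lite xpu inference when WITH_XPU is OFF.

For reference, a minimal sketch of the overrides a plugin must provide
when deriving from PluginTensorRTV2Ext. The IdentityPlugin name and its
pass-through behavior are hypothetical and only illustrate the required
interface; SplitPlugin in this patch is the real example.

    // Assumes trt_plugin.h is included and this sits inside
    // namespace paddle::inference::tensorrt::plugin.
    class IdentityPlugin : public PluginTensorRTV2Ext {
     public:
      const char* getPluginType() const override {
        return "identity_plugin_v2ext";
      }
      nvinfer1::DataType getOutputDataType(
          int index, const nvinfer1::DataType* input_types,
          int nb_inputs) const override {
        return input_types[0];
      }
      nvinfer1::Dims getOutputDimensions(int32_t index,
                                         const nvinfer1::Dims* inputs,
                                         int32_t nb_input) override {
        return inputs[0];  // pass-through: output shape equals input shape
      }
      int enqueue(int batch_size, const void* const* inputs, void** outputs,
                  void* workspace, cudaStream_t stream) override {
        return 0;  // a real plugin launches its CUDA kernels here
      }
      size_t getSerializationSize() const override {
        return getBaseSerializationSize();
      }
      void serialize(void* buffer) const override { serializeBase(buffer); }
      nvinfer1::IPluginV2Ext* clone() const override {
        return new IdentityPlugin();
      }
      void destroy() override { delete this; }
    };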

---
 .../inference/tensorrt/convert/split_op.cc    |   2 +-
 paddle/fluid/inference/tensorrt/engine.cc     |   9 +-
 paddle/fluid/inference/tensorrt/engine.h      |   7 ++
 .../inference/tensorrt/plugin/CMakeLists.txt  |   3 +
 .../tensorrt/plugin/split_op_plugin.cu        |   5 -
 .../tensorrt/plugin/split_op_plugin.h         |  69 +++++++++--
 .../tensorrt/plugin/test_split_plugin.cc      |  58 +++++++++
 .../inference/tensorrt/plugin/trt_plugin.cc   |  78 ++++++++++--
 .../inference/tensorrt/plugin/trt_plugin.h    | 112 +++++++++++++++++-
 python/setup.py.in                            |  11 ++
 10 files changed, 322 insertions(+), 32 deletions(-)
 create mode 100644 paddle/fluid/inference/tensorrt/plugin/test_split_plugin.cc

diff --git a/paddle/fluid/inference/tensorrt/convert/split_op.cc b/paddle/fluid/inference/tensorrt/convert/split_op.cc
index 768c6efaa6..5d494c2093 100644
--- a/paddle/fluid/inference/tensorrt/convert/split_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/split_op.cc
@@ -101,7 +101,7 @@ class SplitOpConverter : public OpConverter {
           engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
       plugin::SplitPlugin* plugin =
           new plugin::SplitPlugin(axis, output_lengths, with_fp16);
-      layer = engine_->AddPlugin(&input, input_num, plugin);
+      layer = engine_->AddPluginV2Ext(&input, input_num, plugin);
     }
 
     std::string layer_name = "split (Output: ";
diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc
index 0bba4581ff..99549fd6b5 100644
--- a/paddle/fluid/inference/tensorrt/engine.cc
+++ b/paddle/fluid/inference/tensorrt/engine.cc
@@ -18,7 +18,7 @@ limitations under the License. */
 #include <glog/logging.h>
 #include <string>
 
-#include "cuda_runtime_api.h"
+#include "cuda_runtime_api.h"  // NOLINT
 #include "paddle/fluid/inference/tensorrt/helper.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/gpu_info.h"
@@ -353,6 +353,13 @@ nvinfer1::IPluginLayer *TensorRTEngine::AddPlugin(
   return network()->addPluginExt(inputs, num_inputs, *plugin);
 }
 
+nvinfer1::IPluginV2Layer *TensorRTEngine::AddPluginV2Ext(
+    nvinfer1::ITensor *const *inputs, int num_inputs,
+    plugin::PluginTensorRTV2Ext *plugin) {
+  owned_plugin_v2ext_.emplace_back(plugin);
+  return network()->addPluginV2(inputs, num_inputs, *plugin);
+}
+
 void TensorRTEngine::freshDeviceId() {
   int count;
   cudaGetDeviceCount(&count);
diff --git a/paddle/fluid/inference/tensorrt/engine.h b/paddle/fluid/inference/tensorrt/engine.h
index 0e399578fa..de2924824f 100644
--- a/paddle/fluid/inference/tensorrt/engine.h
+++ b/paddle/fluid/inference/tensorrt/engine.h
@@ -305,8 +305,14 @@ class TensorRTEngine {
   }
 
   int GetDeviceId() { return device_id_; }
+
   nvinfer1::IPluginLayer* AddPlugin(nvinfer1::ITensor* const* inputs,
                                     int num_inputs, plugin::PluginTensorRT*);
+
+  nvinfer1::IPluginV2Layer* AddPluginV2Ext(nvinfer1::ITensor* const* inputs,
+                                           int num_inputs,
+                                           plugin::PluginTensorRTV2Ext* plugin);
+
   void SetTensorDynamicRange(nvinfer1::ITensor* tensor, float range) {
     quant_dynamic_range_[tensor] = range;
   }
@@ -414,6 +420,7 @@ class TensorRTEngine {
       itensor_map_;
 
   std::vector<std::unique_ptr<plugin::PluginTensorRT>> owned_plugin_;
+  std::vector<std::unique_ptr<plugin::PluginTensorRTV2Ext>> owned_plugin_v2ext_;
 
   // TensorRT related internal members
   template <typename T>
diff --git a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
index e37beb3b8e..7ee16a598d 100644
--- a/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
+++ b/paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
@@ -6,3 +6,6 @@ nv_library(tensorrt_plugin
            qkv_to_context_plugin.cu skip_layernorm_op_plugin.cu slice_op_plugin.cu
            hard_swish_op_plugin.cu stack_op_plugin.cu special_slice_plugin.cu
            DEPS enforce tensorrt_engine prelu tensor bert_encoder_functor)
+
+nv_test(test_split_plugin SRCS test_split_plugin.cc DEPS
+  paddle_framework ${GLOB_OPERATOR_DEPS} tensorrt_plugin)
diff --git a/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu
index 256aa28206..1b5c39f8ff 100644
--- a/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu
+++ b/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.cu
@@ -22,11 +22,6 @@ namespace inference {
 namespace tensorrt {
 namespace plugin {
 
-SplitPlugin* CreateSplitPluginDeserialize(const void* buffer, size_t length) {
-  return new SplitPlugin(buffer, length);
-}
-REGISTER_TRT_PLUGIN("split_plugin", CreateSplitPluginDeserialize);
-
 template <typename T>
 __device__ int upper_bound(T const* vals, int n, T const& key) {
   int i = 0;
diff --git a/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h
index 5c47ec3a99..e43b57357f 100644
--- a/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h
+++ b/paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h
@@ -25,7 +25,7 @@ namespace inference {
 namespace tensorrt {
 namespace plugin {
 
-class SplitPlugin : public PluginTensorRT {
+class SplitPlugin : public PluginTensorRTV2Ext {
  public:
   SplitPlugin() {}
   SplitPlugin(int axis, std::vector<int> const& output_lengths, bool with_fp16)
@@ -39,13 +39,20 @@ class SplitPlugin : public PluginTensorRT {
     DeserializeValue(&serial_data, &serial_length, &output_length_);
   }
 
-  SplitPlugin* clone() const override {
-    auto* ptr = new SplitPlugin(axis_, output_length_, with_fp16_);
+  nvinfer1::IPluginV2Ext* clone() const override {
+    SplitPlugin* ptr = new SplitPlugin(axis_, output_length_, with_fp16_);
+    ptr->setPluginNamespace(this->getPluginNamespace());
     ptr->shareData(this);
     return ptr;
   }
 
-  const char* getPluginType() const override { return "split_plugin"; }
+  nvinfer1::DataType getOutputDataType(int index,
+                                       const nvinfer1::DataType* input_types,
+                                       int nb_inputs) const override {
+    return input_types[0];
+  }
+
+  const char* getPluginType() const override { return "split_plugin_v2ext"; }
   int getNbOutputs() const override { return output_length_.size(); }
   nvinfer1::Dims getOutputDimensions(int index,
                                      const nvinfer1::Dims* input_dims,
@@ -53,17 +60,18 @@ class SplitPlugin : public PluginTensorRT {
 
   int initialize() override;
   void terminate() override;
-  int enqueue(int batchSize, const void* const* inputs, void** outputs,
+  int enqueue(int batch_size, const void* const* inputs, void** outputs,
               void* workspace, cudaStream_t stream) override;
 
+  void destroy() override { delete this; }
+
  protected:
-  size_t getSerializationSize() override {
-    return SerializedSize(getPluginType()) + SerializedSize(axis_) +
-           SerializedSize(output_length_) + getBaseSerializationSize();
+  size_t getSerializationSize() const override {
+    return SerializedSize(axis_) + SerializedSize(output_length_) +
+           getBaseSerializationSize();
   }
 
-  void serialize(void* buffer) override {
-    SerializeValue(&buffer, getPluginType());
+  void serialize(void* buffer) const override {
     serializeBase(buffer);
     SerializeValue(&buffer, axis_);
     SerializeValue(&buffer, output_length_);
@@ -83,6 +91,47 @@ class SplitPlugin : public PluginTensorRT {
   void shareData(const SplitPlugin* another);
 };
 
+class SplitPluginCreator : public nvinfer1::IPluginCreator {
+ public:
+  SplitPluginCreator() {}
+  const char* getPluginName() const override { return "split_plugin_v2ext"; }
+
+  const char* getPluginVersion() const override { return "1"; }
+
+  const nvinfer1::PluginFieldCollection* getFieldNames() override {
+    return &field_collection_;
+  }
+
+  nvinfer1::IPluginV2* createPlugin(
+      const char* name, const nvinfer1::PluginFieldCollection* fc) override {
+    // Not implemented: Paddle constructs this plugin directly in the op
+    // converter instead of building it from plugin fields.
+    return nullptr;
+  }
+
+  nvinfer1::IPluginV2* deserializePlugin(const char* name,
+                                         const void* serial_data,
+                                         size_t serial_length) override {
+    auto plugin = new SplitPlugin(serial_data, serial_length);
+    return plugin;
+  }
+
+  void setPluginNamespace(const char* lib_namespace) override {
+    plugin_namespace_ = lib_namespace;
+  }
+
+  const char* getPluginNamespace() const override {
+    return plugin_namespace_.c_str();
+  }
+
+ private:
+  std::string plugin_namespace_;
+  std::string plugin_name_;
+  nvinfer1::PluginFieldCollection field_collection_{0, nullptr};
+  std::vector<nvinfer1::PluginField> plugin_attributes_;
+};
+
+REGISTER_TRT_PLUGIN_V2(SplitPluginCreator);
+
 #if IS_TRT_VERSION_GE(6000)
 class SplitPluginDynamic : public DynamicPluginTensorRT {
  public:
diff --git a/paddle/fluid/inference/tensorrt/plugin/test_split_plugin.cc b/paddle/fluid/inference/tensorrt/plugin/test_split_plugin.cc
new file mode 100644
index 0000000000..6636513a55
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/plugin/test_split_plugin.cc
@@ -0,0 +1,58 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <gtest/gtest.h>
+#include "paddle/fluid/inference/tensorrt/plugin/split_op_plugin.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+namespace plugin {
+
+TEST(split_op_plugin, test_plugin) {
+  int axis = 1;
+  std::vector<int> output_lengths{1, 1};
+  bool with_fp16 = false;
+  std::vector<nvinfer1::DataType> input_types{nvinfer1::DataType::kFLOAT};
+  std::vector<nvinfer1::Dims> input_dims;
+
+  SplitPlugin sp_plugin(axis, output_lengths, with_fp16);
+  nvinfer1::Dims in_dims;
+  in_dims.nbDims = 4;
+  input_dims.push_back(in_dims);
+  sp_plugin.configurePlugin(input_dims.data(), 1, nullptr, 2,
+                            input_types.data(), nullptr, nullptr, nullptr,
+                            nvinfer1::PluginFormat::kNCHW, 4);
+  sp_plugin.initialize();
+  sp_plugin.getPluginType();
+  sp_plugin.canBroadcastInputAcrossBatch(0);
+  sp_plugin.getNbOutputs();
+  auto clone_plugin = sp_plugin.clone();
+  clone_plugin->setPluginNamespace("test");
+  clone_plugin->destroy();
+  sp_plugin.getOutputDataType(0, input_types.data(), 1);
+  sp_plugin.terminate();
+}
+
+TEST(split_op_plugin, test_plugin_creator) {
+  SplitPluginCreator creator;
+  creator.getFieldNames();
+  creator.createPlugin("test", nullptr);
+  creator.setPluginNamespace("test");
+}
+
+}  // namespace plugin
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
diff --git a/paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc b/paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc
index fd721b1614..55bc786746 100644
--- a/paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc
+++ b/paddle/fluid/inference/tensorrt/plugin/trt_plugin.cc
@@ -19,27 +19,50 @@ namespace inference {
 namespace tensorrt {
 namespace plugin {
 
+inline void Seria(void*& buffer,  // NOLINT
+                  const std::vector<nvinfer1::Dims>& input_dims,
+                  size_t max_batch_size, nvinfer1::DataType data_type,
+                  nvinfer1::PluginFormat data_format, bool with_fp16) {
+  SerializeValue(&buffer, input_dims);
+  SerializeValue(&buffer, max_batch_size);
+  SerializeValue(&buffer, data_type);
+  SerializeValue(&buffer, data_format);
+  SerializeValue(&buffer, with_fp16);
+}
+
+inline void Deseria(void const*& serial_data, size_t& serial_length,  // NOLINT
+                    std::vector<nvinfer1::Dims>* input_dims,
+                    size_t* max_batch_size, nvinfer1::DataType* data_type,
+                    nvinfer1::PluginFormat* data_format, bool* with_fp16) {
+  DeserializeValue(&serial_data, &serial_length, input_dims);
+  DeserializeValue(&serial_data, &serial_length, max_batch_size);
+  DeserializeValue(&serial_data, &serial_length, data_type);
+  DeserializeValue(&serial_data, &serial_length, data_format);
+  DeserializeValue(&serial_data, &serial_length, with_fp16);
+}
+
+inline size_t SeriaSize(const std::vector<nvinfer1::Dims>& input_dims,
+                        size_t max_batch_size, nvinfer1::DataType data_type,
+                        nvinfer1::PluginFormat data_format, bool with_fp16) {
+  return (SerializedSize(input_dims) + SerializedSize(max_batch_size) +
+          SerializedSize(data_type) + SerializedSize(data_format) +
+          SerializedSize(with_fp16));
+}
+
 void PluginTensorRT::serializeBase(void*& buffer) {
-  SerializeValue(&buffer, input_dims_);
-  SerializeValue(&buffer, max_batch_size_);
-  SerializeValue(&buffer, data_type_);
-  SerializeValue(&buffer, data_format_);
-  SerializeValue(&buffer, with_fp16_);
+  Seria(buffer, input_dims_, max_batch_size_, data_type_, data_format_,
+        with_fp16_);
 }
 
 void PluginTensorRT::deserializeBase(void const*& serial_data,
                                      size_t& serial_length) {
-  DeserializeValue(&serial_data, &serial_length, &input_dims_);
-  DeserializeValue(&serial_data, &serial_length, &max_batch_size_);
-  DeserializeValue(&serial_data, &serial_length, &data_type_);
-  DeserializeValue(&serial_data, &serial_length, &data_format_);
-  DeserializeValue(&serial_data, &serial_length, &with_fp16_);
+  Deseria(serial_data, serial_length, &input_dims_, &max_batch_size_,
+          &data_type_, &data_format_, &with_fp16_);
 }
 
 size_t PluginTensorRT::getBaseSerializationSize() {
-  return (SerializedSize(input_dims_) + SerializedSize(max_batch_size_) +
-          SerializedSize(data_type_) + SerializedSize(data_format_) +
-          SerializedSize(with_fp16_));
+  return SeriaSize(input_dims_, max_batch_size_, data_type_, data_format_,
+                   with_fp16_);
 }
 
 bool PluginTensorRT::supportsFormat(nvinfer1::DataType type,
@@ -58,6 +81,35 @@ void PluginTensorRT::configureWithFormat(
   max_batch_size_ = max_batch_size;
 }
 
+void PluginTensorRTV2Ext::serializeBase(void*& buffer) const {
+  Seria(buffer, input_dims_, max_batch_size_, data_type_, data_format_,
+        with_fp16_);
+}
+
+void PluginTensorRTV2Ext::deserializeBase(void const*& serial_data,
+                                          size_t& serial_length) {
+  Deseria(serial_data, serial_length, &input_dims_, &max_batch_size_,
+          &data_type_, &data_format_, &with_fp16_);
+}
+
+size_t PluginTensorRTV2Ext::getBaseSerializationSize() const {
+  return SeriaSize(input_dims_, max_batch_size_, data_type_, data_format_,
+                   with_fp16_);
+}
+
+void PluginTensorRTV2Ext::configurePlugin(
+    const nvinfer1::Dims* input_dims, int32_t nb_inputs,
+    const nvinfer1::Dims* output_dims, int32_t nb_outputs,
+    const nvinfer1::DataType* input_types,
+    const nvinfer1::DataType* output_types, const bool* input_is_broadcast,
+    const bool* output_is_broadcast, nvinfer1::PluginFormat float_format,
+    int32_t max_batch_size) {
+  input_dims_.assign(input_dims, input_dims + nb_inputs);
+  max_batch_size_ = max_batch_size;
+  data_format_ = float_format;
+  data_type_ = input_types[0];
+}
+
 }  // namespace plugin
 }  // namespace tensorrt
 }  // namespace inference
diff --git a/paddle/fluid/inference/tensorrt/plugin/trt_plugin.h b/paddle/fluid/inference/tensorrt/plugin/trt_plugin.h
index b3a3abe5d0..ce3133ae99 100644
--- a/paddle/fluid/inference/tensorrt/plugin/trt_plugin.h
+++ b/paddle/fluid/inference/tensorrt/plugin/trt_plugin.h
@@ -44,6 +44,7 @@ typedef std::function<PluginTensorRT*(const void*, size_t)>
 
 typedef std::function<PluginTensorRT*(void)> PluginConstructFunc;
 
+// Deprecated. Do not inherit from this class; use PluginTensorRTV2Ext instead.
 class PluginTensorRT : public nvinfer1::IPluginExt {
  public:
   PluginTensorRT() : with_fp16_(false) {}
@@ -119,6 +120,114 @@ class PluginTensorRT : public nvinfer1::IPluginExt {
   bool with_fp16_;
 };
 
+// TensorRT introduced IPluginV2Ext after 5.1; Paddle no longer supports
+// TensorRT versions before 5.1.
+class PluginTensorRTV2Ext : public nvinfer1::IPluginV2Ext {
+ public:
+  PluginTensorRTV2Ext() : with_fp16_(false) {}
+  PluginTensorRTV2Ext(const void* serialized_data, size_t length) {}
+
+  nvinfer1::Dims const& getInputDims(int index) const {
+    return input_dims_.at(index);
+  }
+  size_t getMaxBatchSize() const { return max_batch_size_; }
+  nvinfer1::DataType getDataType() const { return data_type_; }
+  nvinfer1::PluginFormat getDataFormat() const { return data_format_; }
+
+  // Methods inherited from IPluginV2Ext
+  virtual nvinfer1::DataType getOutputDataType(
+      int index, const nvinfer1::DataType* input_types,
+      int nb_inputs) const = 0;
+
+  virtual bool isOutputBroadcastAcrossBatch(int32_t output_index,
+                                            const bool* input_is_broadcasted,
+                                            int32_t nb_inputs) const {
+    return false;
+  }
+
+  virtual bool canBroadcastInputAcrossBatch(int32_t input_index) const {
+    return false;
+  }
+
+  void configurePlugin(const nvinfer1::Dims* input_dims, int32_t nb_inputs,
+                       const nvinfer1::Dims* output_dims, int32_t nb_outputs,
+                       const nvinfer1::DataType* input_types,
+                       const nvinfer1::DataType* output_types,
+                       const bool* input_is_broadcast,
+                       const bool* output_is_broadcast,
+                       nvinfer1::PluginFormat float_format,
+                       int32_t max_batch_size) override;
+
+  virtual IPluginV2Ext* clone() const = 0;
+
+  void attachToContext(cudnnContext*, cublasContext*,
+                       nvinfer1::IGpuAllocator*) override {}
+
+  void detachFromContext() override {}
+
+  // Methods inherited from IPluginV2
+  virtual const char* getPluginType() const = 0;
+  const char* getPluginVersion() const override { return "1"; }
+  virtual int32_t getNbOutputs() const { return 1; }
+  virtual nvinfer1::Dims getOutputDimensions(int32_t index,
+                                             const nvinfer1::Dims* inputs,
+                                             int32_t nb_input) = 0;
+  // Check format support. The default is FLOAT32 and NCHW.
+  bool supportsFormat(nvinfer1::DataType type,
+                      nvinfer1::PluginFormat format) const override {
+    return ((type == nvinfer1::DataType::kFLOAT) &&
+            (format == nvinfer1::PluginFormat::kNCHW));
+  }
+  // Initialize the layer for execution.
+  // This is called when the engine is created.
+  int initialize() override { return 0; }
+
+  // Shut down the layer. This is called when the engine is destroyed.
+  void terminate() override {}
+
+  // Find the workspace size required by the layer
+  size_t getWorkspaceSize(int) const override { return 0; }
+
+  // Execute the layer
+  virtual int enqueue(int batch_size, const void* const* inputs, void** outputs,
+                      void* workspace, cudaStream_t stream) = 0;
+
+  // Find the size of the serialization buffer required
+  virtual size_t getSerializationSize() const = 0;
+
+  // Serialize the layer configuration to the buffer.
+  // TensorRT calls this function when serializing the engine; it should
+  // not be called by users.
+  virtual void serialize(void* buffer) const = 0;
+
+  virtual void destroy() = 0;
+
+  void setPluginNamespace(const char* plugin_namespace) override {
+    name_space_ = plugin_namespace;
+  }
+
+  const char* getPluginNamespace() const override {
+    return name_space_.c_str();
+  }
+
+ protected:
+  void deserializeBase(void const*& serial_data,  // NOLINT
+                       size_t& serial_length);    // NOLINT
+  size_t getBaseSerializationSize() const;
+  void serializeBase(void*& buffer) const;  // NOLINT
+
+ protected:
+  std::vector<nvinfer1::Dims> input_dims_;
+  size_t max_batch_size_;
+  nvinfer1::DataType data_type_;
+  nvinfer1::PluginFormat data_format_;
+  std::vector<nvinfer1::ITensor*> inputs_;
+  bool with_fp16_;
+
+ private:
+  std::string name_space_;
+};
+
 #if IS_TRT_VERSION_GE(6000)
 class DynamicPluginTensorRT : public nvinfer1::IPluginV2DynamicExt {
  public:
@@ -184,6 +293,7 @@ class DynamicPluginTensorRT : public nvinfer1::IPluginV2DynamicExt {
   std::string name_space_;
   std::string plugin_base_;
 };
+#endif
 
 template <typename T>
 class TrtPluginRegistrarV2 {
@@ -203,8 +313,6 @@ class TrtPluginRegistrarV2 {
   static paddle::inference::tensorrt::plugin::TrtPluginRegistrarV2<name> \
       plugin_registrar_##name {}
 
-#endif
-
 }  // namespace plugin
 }  // namespace tensorrt
 }  // namespace inference
diff --git a/python/setup.py.in b/python/setup.py.in
index 64cfe6e9cc..69a8bc771a 100644
--- a/python/setup.py.in
+++ b/python/setup.py.in
@@ -336,6 +336,17 @@ if '${WITH_XPU_BKCL}' == 'ON':
     shutil.copy('${XPU_BKCL_LIB}', libs_path)
     package_data['paddle.libs']+=['${XPU_BKCL_LIB_NAME}']
 
+# Only for lite xpu inference.
+if '${WITH_XPU}' == 'OFF' and '${XPU_SDK_ROOT}' != '':
+    xpu_api_lib = os.path.join('${XPU_SDK_ROOT}', 'XTDK/shlib/', 'libxpuapi.so')
+    xpu_rt_lib = os.path.join('${XPU_SDK_ROOT}', 'XTDK/runtime/shlib/', 'libxpurt.so')
+    if os.path.exists(xpu_api_lib):
+        shutil.copy(xpu_api_lib, libs_path)
+        package_data['paddle.libs']+=['libxpuapi.so']
+    if os.path.exists(xpu_rt_lib):
+        shutil.copy(xpu_rt_lib, libs_path)
+        package_data['paddle.libs']+=['libxpurt.so']
+
 ### Old custom op extension mechanism related, will be removed in 2.1.0 ###
 # copy libpaddle_framework.so to libs on linux
 if sys.platform.startswith('linux'):
-- 
GitLab