diff --git a/paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.cc b/paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.cc index 3c20b6d1e725273dbfdc20c01fb01deea4e8d88e..76b0832c546b92068364ba6b2eda65a04742e5f0 100644 --- a/paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.cc +++ b/paddle/fluid/inference/tensorrt/plugin/trt_plugin_factory.cc @@ -25,8 +25,10 @@ PluginTensorRT* PluginFactoryTensorRT::createPlugin(const char* layer_name, const char* plugin_type; DeserializeValue(&serial_data, &serial_length, &plugin_type); - PADDLE_ENFORCE(Has(plugin_type), - "trt plugin type %s does not exists, check it.", plugin_type); + PADDLE_ENFORCE_EQ( + Has(plugin_type), true, + platform::errors::NotFound( + "trt plugin type %s does not exist, check it.", plugin_type)); auto plugin = plugin_registry_[plugin_type](serial_data, serial_length); owned_plugins_.emplace_back(plugin); diff --git a/paddle/fluid/inference/tensorrt/plugin/trt_plugin_utils.h b/paddle/fluid/inference/tensorrt/plugin/trt_plugin_utils.h index 18037179c7b98952b6088361954e869ecedfb2c7..6fcb70c6d3299f830e1e95e328b2645aedf9cc31 100644 --- a/paddle/fluid/inference/tensorrt/plugin/trt_plugin_utils.h +++ b/paddle/fluid/inference/tensorrt/plugin/trt_plugin_utils.h @@ -103,7 +103,12 @@ struct Serializer, DeserializeValue(buffer, buffer_size, &size); value->resize(size); size_t nbyte = value->size() * sizeof(T); - PADDLE_ENFORCE_GE(*buffer_size, nbyte); + PADDLE_ENFORCE_GE( + *buffer_size, nbyte, + platform::errors::InvalidArgument("Expect buffer size >= value size in " "trt plugin deserialization, but got " "buffer size = %d, value size = %d.", *buffer_size, nbyte)); std::memcpy(value->data(), *buffer, nbyte); reinterpret_cast(*buffer) += nbyte; *buffer_size -= nbyte;