未验证 提交 fda54c02 编写于 作者: P Pei Yang 提交者: GitHub

errmsg refine of trt plugin (#27309)

上级 905e2346
...@@ -25,8 +25,10 @@ PluginTensorRT* PluginFactoryTensorRT::createPlugin(const char* layer_name, ...@@ -25,8 +25,10 @@ PluginTensorRT* PluginFactoryTensorRT::createPlugin(const char* layer_name,
const char* plugin_type; const char* plugin_type;
DeserializeValue(&serial_data, &serial_length, &plugin_type); DeserializeValue(&serial_data, &serial_length, &plugin_type);
PADDLE_ENFORCE(Has(plugin_type), PADDLE_ENFORCE_EQ(
"trt plugin type %s does not exists, check it.", plugin_type); Has(plugin_type), true,
platform::errors::NotFound(
"trt plugin type %s does not exists, check it.", plugin_type));
auto plugin = plugin_registry_[plugin_type](serial_data, serial_length); auto plugin = plugin_registry_[plugin_type](serial_data, serial_length);
owned_plugins_.emplace_back(plugin); owned_plugins_.emplace_back(plugin);
......
...@@ -103,7 +103,12 @@ struct Serializer<std::vector<T>, ...@@ -103,7 +103,12 @@ struct Serializer<std::vector<T>,
DeserializeValue(buffer, buffer_size, &size); DeserializeValue(buffer, buffer_size, &size);
value->resize(size); value->resize(size);
size_t nbyte = value->size() * sizeof(T); size_t nbyte = value->size() * sizeof(T);
PADDLE_ENFORCE_GE(*buffer_size, nbyte); PADDLE_ENFORCE_GE(
*buffer_size, nbyte,
platform::errors::InvalidArgument("Expect buffer size >= value size in "
"trt plugin deserialization, but got "
"buffer size = %d, value size = %d.",
*buffer_size, nbyte));
std::memcpy(value->data(), *buffer, nbyte); std::memcpy(value->data(), *buffer, nbyte);
reinterpret_cast<char const*&>(*buffer) += nbyte; reinterpret_cast<char const*&>(*buffer) += nbyte;
*buffer_size -= nbyte; *buffer_size -= nbyte;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册