Unverified commit 50ac7dbf authored by Shang Zhizhou, committed by GitHub

Trt elementwise plugin serialize (#31587)

* add serialize unittest

* fix element_op trt plugin serialize bug
Parent ef0dd3ef
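For context, the diff below replaces the previously stubbed-out getSerializationSize/serialize bodies and the no-op deserialization constructor so that the plugin's two fields (the elementwise type string and the axis) survive an engine serialize/deserialize round trip. The following minimal standalone C++ sketch only illustrates that buffer round-trip idea; it deliberately does not use Paddle's SerializedSize/SerializeValue/DeserializeValue helpers, and every name in it (pack_plugin_fields, unpack_plugin_fields) is hypothetical.

// Minimal standalone sketch of a (type string + axis int) round trip.
// Illustration only: not Paddle's serialization helpers, and all names
// below are hypothetical.
#include <cassert>
#include <cstdint>
#include <cstring>
#include <string>
#include <vector>

// Serialize as: [uint32 length][type bytes][int32 axis].
std::vector<char> pack_plugin_fields(const std::string& type, int axis) {
  std::vector<char> buf(sizeof(uint32_t) + type.size() + sizeof(int32_t));
  char* p = buf.data();
  uint32_t len = static_cast<uint32_t>(type.size());
  std::memcpy(p, &len, sizeof(len));
  p += sizeof(len);
  std::memcpy(p, type.data(), type.size());
  p += type.size();
  int32_t axis32 = static_cast<int32_t>(axis);
  std::memcpy(p, &axis32, sizeof(axis32));
  return buf;
}

// Deserialize the same layout back into (type, axis).
void unpack_plugin_fields(const void* data, size_t length, std::string* type,
                          int* axis) {
  const char* p = static_cast<const char*>(data);
  uint32_t len = 0;
  std::memcpy(&len, p, sizeof(len));
  p += sizeof(len);
  type->assign(p, len);
  p += len;
  int32_t axis32 = 0;
  std::memcpy(&axis32, p, sizeof(axis32));
  *axis = axis32;
  assert(sizeof(len) + len + sizeof(axis32) == length);
}

int main() {
  std::vector<char> buf = pack_plugin_fields("add", /*axis=*/-1);
  std::string type;
  int axis = 0;
  unpack_plugin_fields(buf.data(), buf.size(), &type, &axis);
  assert(type == "add" && axis == -1);
  return 0;
}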
...@@ -152,9 +152,14 @@ int ElementWisePlugin::enqueue(int batch_size, const void *const *inputs,

int ElementwisePluginDynamic::initialize() { return 0; }

size_t ElementwisePluginDynamic::getSerializationSize() const {
  return SerializedSize(type_.c_str()) + SerializedSize(axis_);
}

void ElementwisePluginDynamic::serialize(void *buffer) const {
  SerializeValue(&buffer, type_.c_str());
  SerializeValue(&buffer, axis_);
}
nvinfer1::DimsExprs ElementwisePluginDynamic::getOutputDimensions(
    int output_index, const nvinfer1::DimsExprs *inputs, int nb_inputs,
......
...@@ -92,7 +92,12 @@ class ElementwisePluginDynamic : public DynamicPluginTensorRT {
 public:
  explicit ElementwisePluginDynamic(const std::string& type, int axis)
      : type_(type), axis_(axis) {}

  ElementwisePluginDynamic(void const* serialData, size_t serialLength) {
    const char* elementwise_type;
    DeserializeValue(&serialData, &serialLength, &elementwise_type);
    type_ = std::string(elementwise_type);
    DeserializeValue(&serialData, &serialLength, &axis_);
  }

  nvinfer1::IPluginV2DynamicExt* clone() const override {
    return new ElementwisePluginDynamic(type_, axis_);
  }
...@@ -138,6 +143,46 @@ class ElementwisePluginDynamic : public DynamicPluginTensorRT {
  std::string type_;
  int axis_;
};

class ElementwisePluginV2Creator : public nvinfer1::IPluginCreator {
 public:
  ElementwisePluginV2Creator() {}

  const char* getPluginName() const override { return "elementwise_plugin"; }

  const char* getPluginVersion() const override { return "1"; }

  const nvinfer1::PluginFieldCollection* getFieldNames() override {
    return &field_collection_;
  }

  nvinfer1::IPluginV2* createPlugin(
      const char* name, const nvinfer1::PluginFieldCollection* fc) override {
    return nullptr;
  }

  nvinfer1::IPluginV2* deserializePlugin(const char* name,
                                         const void* serial_data,
                                         size_t serial_length) override {
    auto plugin = new ElementwisePluginDynamic(serial_data, serial_length);
    return plugin;
  }

  void setPluginNamespace(const char* lib_namespace) override {
    plugin_namespace_ = lib_namespace;
  }

  const char* getPluginNamespace() const override {
    return plugin_namespace_.c_str();
  }

 private:
  std::string plugin_namespace_;
  std::string plugin_name_;
  nvinfer1::PluginFieldCollection field_collection_{0, nullptr};
  std::vector<nvinfer1::PluginField> plugin_attributes_;
};

REGISTER_TRT_PLUGIN_V2(ElementwisePluginV2Creator);
#endif
}  // namespace plugin
......
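A note on the new ElementwisePluginV2Creator above: assuming Paddle's REGISTER_TRT_PLUGIN_V2 macro registers the creator with TensorRT's global plugin registry under the name/version it reports ("elementwise_plugin", "1"), a deserialized engine can find the creator again and rebuild the plugin from the serialized bytes. The sketch below shows that lookup path through the public TensorRT API; it is not part of the commit, the helper name and byte buffer are placeholders, and in practice the TensorRT runtime drives this call itself while deserializing an engine.

// Hedged sketch (not part of the commit): rebuilding the plugin from
// serialized bytes via the plugin registry. The helper name and `bytes`
// buffer are placeholders.
#include <NvInfer.h>
#include <vector>

nvinfer1::IPluginV2* rebuild_elementwise_plugin(const std::vector<char>& bytes) {
  auto* creator = getPluginRegistry()->getPluginCreator(
      "elementwise_plugin", "1");  // name/version reported by the creator
  if (creator == nullptr) return nullptr;
  return creator->deserializePlugin("elementwise_plugin", bytes.data(),
                                    bytes.size());
}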
...@@ -414,6 +414,58 @@ class TensorRTSubgraphPassElementwiseMulTest(
        return fluid.layers.elementwise_mul(x=data1, y=data2)

class TensorRTSubgraphPassElementwiseSerializeTest(
        TensorRTSubgraphPassElementwiseTest):
    def setUp(self):
        super(TensorRTSubgraphPassElementwiseSerializeTest, self).setUp()
        self.trt_parameters = TensorRTSubgraphPassElementwiseTest.TensorRTParam(
            1 << 30, 32, 0, AnalysisConfig.Precision.Float32, True, False)

    def test_check_output(self):
        if os.path.exists(self.path + "_opt_cache"):
            shutil.rmtree(self.path + "_opt_cache")
        super(TensorRTSubgraphPassElementwiseSerializeTest,
              self).test_check_output()

class TensorRTSubgraphPassElementwiseBroadcastDynamicTest(InferencePassTest):
    def setUp(self):
        with fluid.program_guard(self.main_program, self.startup_program):
            data1 = fluid.data(
                name="data1", shape=[-1, 3, 64, 64], dtype="float32")
            data2 = fluid.data(name="data2", shape=[64, 64], dtype="float32")
            eltwise_out = self.append_eltwise(data1, data2)
            out = fluid.layers.batch_norm(eltwise_out, is_test=True)
        self.feeds = {
            "data1": np.random.random([1, 3, 64, 64]).astype("float32"),
            "data2": np.random.random([64, 64]).astype("float32"),
        }
        self.enable_trt = True
        self.trt_parameters = TensorRTSubgraphPassElementwiseBroadcastDynamicTest.TensorRTParam(
            1 << 30, 32, 0, AnalysisConfig.Precision.Float32, True, False)
        self.dynamic_shape_params = TensorRTSubgraphPassElementwiseBroadcastDynamicTest.DynamicShapeParam(
            {
                'data1': [1, 3, 8, 64],
                'data2': [8, 64]
            }, {
                'data1': [1, 3, 512, 64],
                'data2': [512, 64]
            }, {
                'data1': [1, 3, 256, 64],
                'data2': [256, 64]
            }, False)
        self.fetch_list = [out]

    def append_eltwise(self, data1, data2):
        return fluid.layers.elementwise_add(x=data1, y=data2)

    def test_check_output(self):
        if os.path.exists(self.path + "_opt_cache"):
            shutil.rmtree(self.path + "_opt_cache")
        if core.is_compiled_with_cuda():
            use_gpu = True
            self.check_output_with_option(use_gpu)
            self.assertTrue(
                PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))

class TensorRTSubgraphPassShuffleChannelTest(InferencePassTest):
    def setUp(self):
        with fluid.program_guard(self.main_program, self.startup_program):
......