From ae6e40a7fdc7381a21e189087f945ab0c924c342 Mon Sep 17 00:00:00 2001 From: Pei Yang Date: Mon, 28 Sep 2020 14:06:00 +0800 Subject: [PATCH] Add unittests and OP version registry for tensorrt_subgraph_pass (#27544) * add unittests and op version register for tensorrt_subgraph_pass * rename to test_trt_subgraph_pass.py * fix softmax converter diff when padding dim=1 --- .../ir_passes/tensorrt_subgraph_pass.cc | 29 + .../inference/tensorrt/convert/softmax_op.cc | 35 +- paddle/fluid/inference/tensorrt/op_teller.cc | 9 +- .../tensorrt/plugin/layer_norm_op_plugin.cu | 24 +- .../ir/inference/inference_pass_test.py | 14 +- .../ir/inference/test_trt_subgraph_pass.py | 546 ++++++++++++++++++ 6 files changed, 642 insertions(+), 15 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_trt_subgraph_pass.py diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc index 46612c1c5b..1d4725ddab 100644 --- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc +++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc @@ -18,6 +18,7 @@ #include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/subgraph_detector.h" +#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/inference/analysis/helper.h" #include "paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h" #include "paddle/fluid/inference/tensorrt/convert/op_converter.h" @@ -358,3 +359,31 @@ REGISTER_PASS(tensorrt_subgraph_pass, .RequirePassAttr("max_batch_size") .RequirePassAttr("workspace_size") .RequirePassAttr("min_subgraph_size"); + +REGISTER_PASS_CAPABILITY(tensorrt_subgraph_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("conv2d", 0) + .EQ("pool2d", 0) + .EQ("relu", 0) + .EQ("softmax", 0) + .EQ("sigmoid", 0) + .EQ("hard_swish", 0) + .EQ("depthwise_conv2d", 0) + .EQ("batch_norm", 0) + .EQ("concat", 0) + .EQ("tanh", 0) + .EQ("pad", 0) + .EQ("elementwise_add", 0) + .EQ("elementwise_mul", 0) + .EQ("prelu", 0) + .LE("conv2d_transpose", 1) + .LE("leaky_relu", 1) + .EQ("fc", 0) + .EQ("shuffle_channel", 0) + .EQ("swish", 0) + .EQ("split", 0) + .EQ("instance_norm", 0) + .EQ("gelu", 0) + .EQ("layer_norm", 0) + .EQ("scale", 0)); diff --git a/paddle/fluid/inference/tensorrt/convert/softmax_op.cc b/paddle/fluid/inference/tensorrt/convert/softmax_op.cc index 0388154427..79992065a2 100644 --- a/paddle/fluid/inference/tensorrt/convert/softmax_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/softmax_op.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include #include "paddle/fluid/inference/tensorrt/convert/op_converter.h" namespace paddle { @@ -39,9 +40,41 @@ class SoftMaxOpConverter : public OpConverter { framework::OpDesc op_desc(op, nullptr); // Declare inputs auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]); + nvinfer1::Dims input_shape = input1->getDimensions(); + int input_dims = input_shape.nbDims; + int axis = op_desc.HasAttr("axis") + ? BOOST_GET_CONST(int, op_desc.GetAttr("axis")) + : -1; + auto* layer = TRT_ENGINE_ADD_LAYER(engine_, SoftMax, *const_cast(input1)); - + uint32_t axes = std::max(0, input_dims - 3); + // TODO(cryoco): Poor workaround. Fix padded dims problem when TRT layers + // support Nd. + int padded_dims = 0; + int explicit_batch = 0; + if (engine_->with_dynamic_shape()) explicit_batch = 1; + for (int i = input_dims - 1; i > explicit_batch; i--) { + if (input_shape.d[i] == 1) { + padded_dims += 1; + } else { + break; + } + } + if (!engine_->with_dynamic_shape()) { + if (axis == -1) { + axes = input_dims - 1 - padded_dims; + } else { + axes = axis; + } + } else { + if (axis == -1) { + axes = input_dims - 1 - padded_dims; + } else { + axes = axis + 1; + } + } + layer->setAxes(1 << axes); auto output_name = op_desc.Output("Out")[0]; RreplenishLayerAndOutput(layer, "softmax", {output_name}, test_mode); diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 23aacedd69..21ca678397 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -113,7 +113,14 @@ bool OpTeller::Tell(const std::string& op_type, const framework::OpDesc& desc, op_type == "depthwise_conv2d" || op_type == "conv2d_transpose") { std::vector paddings = BOOST_GET_CONST(std::vector, desc.GetAttr("paddings")); - if (paddings.size() > 2) return false; + + std::string padding_algorithm = "EXPLICIT"; + if (desc.HasAttr("padding_algorithm")) + padding_algorithm = + BOOST_GET_CONST(std::string, desc.GetAttr("padding_algorithm")); + if (paddings.size() > 2 || + (padding_algorithm == "SAME" && op_type != "pool2d")) + return false; } if ((*teller)(op_type, desc, use_no_calib_int8)) return true; } diff --git a/paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.cu index 7c905a245a..8af036a0e8 100644 --- a/paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.cu @@ -50,10 +50,18 @@ int LayerNormPlugin::enqueue(int batch_size, const void *const *inputs, float *output = reinterpret_cast(outputs)[0]; int begin_norm_axis = begin_norm_axis_; float eps = eps_; - int c = input_dims.d[begin_norm_axis - 1]; - scale_t.Resize(framework::make_ddim({c})); - bias_t.Resize(framework::make_ddim({c})); + std::vector input_shape; + input_shape.push_back(batch_size); + for (int i = 0; i < input_dims.nbDims; i++) { + input_shape.push_back(input_dims.d[i]); + } + const auto input_ddim = framework::make_ddim(input_shape); + auto matrix_dim = framework::flatten_to_2d(input_ddim, begin_norm_axis - 1); + int feature_size = static_cast(matrix_dim[1]); + + scale_t.Resize(framework::make_ddim({feature_size})); + bias_t.Resize(framework::make_ddim({feature_size})); mean_t.Resize(framework::make_ddim(mean_shape_)); variance_t.Resize(framework::make_ddim(variance_shape_)); int device_id; @@ -63,15 +71,11 @@ int LayerNormPlugin::enqueue(int batch_size, const void *const *inputs, float *mean_d = mean_t.mutable_data(platform::CUDAPlace(device_id)); float *variance_d = variance_t.mutable_data(platform::CUDAPlace(device_id)); - cudaMemcpyAsync(scale_d, scale_.data(), sizeof(float) * c, + cudaMemcpyAsync(scale_d, scale_.data(), sizeof(float) * feature_size, cudaMemcpyHostToDevice, stream); - cudaMemcpyAsync(bias_d, bias_.data(), sizeof(float) * c, + cudaMemcpyAsync(bias_d, bias_.data(), sizeof(float) * feature_size, cudaMemcpyHostToDevice, stream); - std::vector input_shape; - input_shape.push_back(batch_size); - for (int i = 0; i < input_dims.nbDims; i++) { - input_shape.push_back(input_dims.d[i]); - } + paddle::operators::LayerNormDirectCUDAFunctor layer_norm; layer_norm(stream, input, input_shape, bias_d, scale_d, output, mean_d, variance_d, begin_norm_axis, eps); diff --git a/python/paddle/fluid/tests/unittests/ir/inference/inference_pass_test.py b/python/paddle/fluid/tests/unittests/ir/inference/inference_pass_test.py index d3a53bbbff..2af86dfd3c 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/inference_pass_test.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/inference_pass_test.py @@ -133,7 +133,7 @@ class InferencePassTest(unittest.TestCase): for place_ in use_gpu: self.check_output_with_option(place_, atol) - def check_output_with_option(self, use_gpu, atol=1e-5): + def check_output_with_option(self, use_gpu, atol=1e-5, flatten=False): ''' Check whether calculating on CPU and GPU, enable TensorRT or disable TensorRT, enable MKLDNN or disable MKLDNN @@ -155,9 +155,13 @@ class InferencePassTest(unittest.TestCase): format(device)) for out, analysis_output in zip(outs, analysis_outputs): + out = np.array(out) + if flatten: + out = out.flatten() + analysis_output = analysis_output.flatten() self.assertTrue( np.allclose( - np.array(out), analysis_output, atol=atol), + out, analysis_output, atol=atol), "Output has diff between inference and training forward at {} ". format(device)) @@ -172,9 +176,13 @@ class InferencePassTest(unittest.TestCase): "The number of outputs is different between GPU and TensorRT. ") for out, tensorrt_output in zip(outs, tensorrt_outputs): + out = np.array(out) + if flatten: + out = out.flatten() + tensorrt_output = tensorrt_output.flatten() self.assertTrue( np.allclose( - np.array(out), tensorrt_output, atol=atol), + out, tensorrt_output, atol=atol), "Output has diff between GPU and TensorRT. ") # Check whether the mkldnn results and the CPU results are the same. diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_subgraph_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_subgraph_pass.py new file mode 100644 index 0000000000..c651f69a55 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_subgraph_pass.py @@ -0,0 +1,546 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +from inference_pass_test import InferencePassTest +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid.core import PassVersionChecker +from paddle.fluid.core import AnalysisConfig + + +class TensorRTSubgraphPassConvTest(InferencePassTest): + def setUp(self): + self.set_params() + with fluid.program_guard(self.main_program, self.startup_program): + data = fluid.data( + name="data", shape=[-1, 6, 64, 64], dtype="float32") + conv_out = fluid.layers.conv2d( + input=data, + num_filters=self.conv_num_filters, + filter_size=self.conv_filter_size, + groups=self.conv_groups, + padding=self.conv_padding, + bias_attr=False, + act=None) + self.feeds = { + "data": np.random.random([1, 6, 64, 64]).astype("float32"), + } + self.enable_trt = True + self.trt_parameters = TensorRTSubgraphPassConvTest.TensorRTParam( + 1 << 30, 32, 0, AnalysisConfig.Precision.Float32, False, False) + self.fetch_list = [conv_out] + + def set_params(self): + self.conv_num_filters = 6 + self.conv_filter_size = 6 + self.conv_groups = 3 + self.conv_padding = [1, 1] + + def test_check_output(self): + if core.is_compiled_with_cuda(): + use_gpu = True + self.check_output_with_option(use_gpu) + self.assertTrue( + PassVersionChecker.IsCompatible('tensorrt_subgraph_pass')) + + +class TensorRTSubgraphPassConvValidPaddingTest(TensorRTSubgraphPassConvTest): + def set_params(self): + self.conv_num_filters = 6 + self.conv_filter_size = 6 + self.conv_groups = 3 + self.conv_padding = 'VALID' + + +''' +# conv2d padded in 'SAME' mode is not yet supported in TRT, reopen this when support is complete. +class TensorRTSubgraphPassConvSamePaddingTest(InferencePassTest): + def set_params(self): + self.conv_num_filters = 6 + self.conv_filter_size = 6 + self.conv_groups = 3 + self.conv_padding = 'SAME' +''' + + +class TensorRTSubgraphPassDepthwiseConvTest(TensorRTSubgraphPassConvTest): + def set_params(self): + self.conv_num_filters = 6 + self.conv_filter_size = 6 + self.conv_groups = 6 + self.conv_padding = [1, 1] + + +class TensorRTSubgraphPassConvTransposeTest(InferencePassTest): + def setUp(self): + self.set_params() + with fluid.program_guard(self.main_program, self.startup_program): + data = fluid.data( + name="data", shape=[-1, 6, 64, 64], dtype="float32") + conv_out = fluid.layers.conv2d_transpose( + input=data, + num_filters=self.conv_num_filters, + filter_size=self.conv_filter_size, + groups=self.conv_groups, + padding=self.conv_padding, + bias_attr=False, + act=None) + self.feeds = { + "data": np.random.random([1, 6, 64, 64]).astype("float32"), + } + self.enable_trt = True + self.trt_parameters = TensorRTSubgraphPassConvTransposeTest.TensorRTParam( + 1 << 30, 32, 0, AnalysisConfig.Precision.Float32, False, False) + self.fetch_list = [conv_out] + + def set_params(self): + self.conv_num_filters = 6 + self.conv_filter_size = 6 + self.conv_groups = 1 + self.conv_padding = [1, 1] + + def test_check_output(self): + if core.is_compiled_with_cuda(): + use_gpu = True + self.check_output_with_option(use_gpu) + self.assertTrue( + PassVersionChecker.IsCompatible('tensorrt_subgraph_pass')) + + +class TensorRTSubgraphPassConvTransposeValidPaddingTest( + TensorRTSubgraphPassConvTransposeTest): + def set_params(self): + self.conv_num_filters = 6 + self.conv_filter_size = 6 + self.conv_groups = 1 + self.conv_padding = 'VALID' + + +''' +# conv2d_transpose padded in 'SAME' mode is not yet supported in TRT, reopen this when support is complete. +class TensorRTSubgraphPassConvTransposeSamePaddingTest(TensorRTSubgraphPassConvTransposeTest): + def set_params(self): + self.conv_num_filters = 6 + self.conv_filter_size = 6 + self.conv_groups = 1 + self.conv_padding = 'SAME' +''' + + +class TensorRTSubgraphPassDepthwiseConvTransposeTest( + TensorRTSubgraphPassConvTransposeTest): + def set_params(self): + self.conv_num_filters = 6 + self.conv_filter_size = 6 + self.conv_groups = 1 + self.conv_padding = [1, 1] + + +class TensorRTSubgraphPassFcTest(InferencePassTest): + def setUp(self): + with fluid.program_guard(self.main_program, self.startup_program): + data = fluid.data( + name="data", shape=[-1, 6, 64, 64], dtype="float32") + fc_out = fluid.layers.fc(input=[data], act=None, size=1000) + reshape_out = fluid.layers.reshape(x=fc_out, shape=[1, 1000]) + self.feeds = { + "data": np.random.random([1, 6, 64, 64]).astype("float32"), + } + self.enable_trt = True + self.trt_parameters = TensorRTSubgraphPassFcTest.TensorRTParam( + 1 << 30, 32, 0, AnalysisConfig.Precision.Float32, False, False) + self.fetch_list = [reshape_out] + + def test_check_output(self): + if core.is_compiled_with_cuda(): + use_gpu = True + # TRT output shape of fc is (1, 1000, 1, 1). To compare the output value only, flatten the results. + self.check_output_with_option(use_gpu, flatten=True) + self.assertTrue( + PassVersionChecker.IsCompatible('tensorrt_subgraph_pass')) + + +class TensorRTSubgraphPassPoolTest(InferencePassTest): + def setUp(self): + self.set_params() + with fluid.program_guard(self.main_program, self.startup_program): + data = fluid.data( + name="data", shape=[-1, 6, 64, 64], dtype="float32") + pool_out = fluid.layers.pool2d( + input=data, + pool_size=self.pool_size, + pool_type=self.pool_type, + pool_stride=self.pool_stride, + pool_padding=self.pool_padding, + global_pooling=self.global_pooling, + ceil_mode=self.ceil_mode, + exclusive=self.exclusive) + out = fluid.layers.batch_norm(pool_out, is_test=True) + self.feeds = { + "data": np.random.random([1, 6, 64, 64]).astype("float32"), + } + self.enable_trt = True + self.trt_parameters = TensorRTSubgraphPassPoolTest.TensorRTParam( + 1 << 30, 32, 0, AnalysisConfig.Precision.Float32, False, False) + self.fetch_list = [out] + + def set_params(self): + self.pool_size = 2 + self.pool_type = 'max' + self.pool_stride = 1 + self.pool_padding = 0 + self.global_pooling = False + self.ceil_mode = False + self.exclusive = False + + def test_check_output(self): + if core.is_compiled_with_cuda(): + use_gpu = True + self.check_output_with_option(use_gpu) + self.assertTrue( + PassVersionChecker.IsCompatible('tensorrt_subgraph_pass')) + + +class TensorRTSubgraphPassAvgPoolTest(TensorRTSubgraphPassPoolTest): + def set_params(self): + self.pool_size = 2 + self.pool_type = 'avg' + self.pool_stride = 1 + self.pool_padding = 0 + self.global_pooling = False + self.ceil_mode = False + self.exclusive = False + + +class TensorRTSubgraphPassGlobalPoolTest(TensorRTSubgraphPassPoolTest): + def set_params(self): + self.pool_size = 2 + self.pool_type = 'max' + self.pool_stride = 1 + self.pool_padding = 0 + self.global_pooling = True + self.ceil_mode = False + self.exclusive = False + + +class TensorRTSubgraphPassCeilPoolTest(TensorRTSubgraphPassPoolTest): + def set_params(self): + self.pool_size = 2 + self.pool_type = 'max' + self.pool_stride = 1 + self.pool_padding = 0 + self.global_pooling = False + self.ceil_mode = True + self.exclusive = False + + +class TensorRTSubgraphPassExclusivePoolTest(TensorRTSubgraphPassPoolTest): + def set_params(self): + self.pool_size = 2 + self.pool_type = 'max' + self.pool_stride = 1 + self.pool_padding = 0 + self.global_pooling = False + self.ceil_mode = False + self.exclusive = True + + +class TensorRTSubgraphPassSamePaddingPoolTest(InferencePassTest): + def set_params(self): + self.pool_size = 2 + self.pool_type = 'max' + self.pool_stride = 1 + self.pool_padding = 'SAME' + self.global_pooling = False + self.ceil_mode = False + self.exclusive = False + + +class TensorRTSubgraphPassValidPaddingPoolTest(InferencePassTest): + def set_params(self): + self.pool_size = 2 + self.pool_type = 'max' + self.pool_stride = 1 + self.pool_padding = 'VALID' + self.global_pooling = False + self.ceil_mode = False + self.exclusive = False + + +class TensorRTSubgraphPassActivationTest(InferencePassTest): + def setUp(self): + with fluid.program_guard(self.main_program, self.startup_program): + data = fluid.data( + name="data", shape=[-1, 6, 64, 64], dtype="float32") + act_out = self.append_act(data) + out = fluid.layers.batch_norm(act_out, is_test=True) + self.feeds = { + "data": np.random.random([1, 6, 64, 64]).astype("float32"), + } + self.enable_trt = True + self.trt_parameters = TensorRTSubgraphPassActivationTest.TensorRTParam( + 1 << 30, 32, 0, AnalysisConfig.Precision.Float32, False, False) + self.fetch_list = [out] + + def append_act(self, x): + return fluid.layers.relu(x) + + def test_check_output(self): + if core.is_compiled_with_cuda(): + use_gpu = True + self.check_output_with_option(use_gpu) + self.assertTrue( + PassVersionChecker.IsCompatible('tensorrt_subgraph_pass')) + + +class TensorRTSubgraphPassLeakyReluTest(TensorRTSubgraphPassActivationTest): + def append_act(self, x): + return fluid.layers.leaky_relu(x) + + +class TensorRTSubgraphPassRelu6Test(TensorRTSubgraphPassActivationTest): + def append_act(self, x): + return fluid.layers.relu6(x) + + +class TensorRTSubgraphPassSoftMaxTest(TensorRTSubgraphPassActivationTest): + def append_act(self, x): + return fluid.layers.softmax(x) + + +class TensorRTSubgraphPassSigmoidTest(TensorRTSubgraphPassActivationTest): + def append_act(self, x): + return fluid.layers.sigmoid(x) + + +class TensorRTSubgraphPassHardSwishTest(TensorRTSubgraphPassActivationTest): + def append_act(self, x): + return fluid.layers.hard_swish(x) + + +class TensorRTSubgraphPassHardSigmoidTest(TensorRTSubgraphPassActivationTest): + def append_act(self, x): + return fluid.layers.hard_sigmoid(x) + + +class TensorRTSubgraphPassTanhTest(TensorRTSubgraphPassActivationTest): + def append_act(self, x): + return fluid.layers.tanh(x) + + +class TensorRTSubgraphPassSwishTest(TensorRTSubgraphPassActivationTest): + def append_act(self, x): + return fluid.layers.swish(x) + + +class TensorRTSubgraphPassPreluAllTest(TensorRTSubgraphPassActivationTest): + def append_act(self, x): + return fluid.layers.prelu(x, mode='all') + + +class TensorRTSubgraphPassPreluChannelTest(TensorRTSubgraphPassActivationTest): + def append_act(self, x): + return fluid.layers.prelu(x, mode='channel') + + +class TensorRTSubgraphPassPreluElementTest(TensorRTSubgraphPassActivationTest): + def append_act(self, x): + return fluid.layers.prelu(x, mode='element') + + +class TensorRTSubgraphPassGeluTest(TensorRTSubgraphPassActivationTest): + def append_act(self, x): + return fluid.layers.gelu(x) + + +class TensorRTSubgraphPassConcatTest(InferencePassTest): + def setUp(self): + with fluid.program_guard(self.main_program, self.startup_program): + data1 = fluid.data( + name="data1", shape=[-1, 3, 64, 64], dtype="float32") + data2 = fluid.data( + name="data2", shape=[-1, 3, 64, 64], dtype="float32") + concat_out = fluid.layers.concat([data1, data2], axis=2) + out = fluid.layers.batch_norm(concat_out, is_test=True) + self.feeds = { + "data1": np.random.random([1, 3, 64, 64]).astype("float32"), + "data2": np.random.random([1, 3, 64, 64]).astype("float32"), + } + self.enable_trt = True + self.trt_parameters = TensorRTSubgraphPassConcatTest.TensorRTParam( + 1 << 30, 32, 0, AnalysisConfig.Precision.Float32, False, False) + self.fetch_list = [out] + + def test_check_output(self): + if core.is_compiled_with_cuda(): + use_gpu = True + self.check_output_with_option(use_gpu) + self.assertTrue( + PassVersionChecker.IsCompatible('tensorrt_subgraph_pass')) + + +class TensorRTSubgraphPassSplitTest(InferencePassTest): + def setUp(self): + with fluid.program_guard(self.main_program, self.startup_program): + data = fluid.data( + name="data", shape=[-1, 3, 64, 64], dtype="float32") + split_out = fluid.layers.split(data, dim=-1, num_or_sections=2) + out = fluid.layers.batch_norm(split_out[0], is_test=True) + self.feeds = { + "data": np.random.random([1, 3, 64, 64]).astype("float32"), + } + self.enable_trt = True + self.trt_parameters = TensorRTSubgraphPassSplitTest.TensorRTParam( + 1 << 30, 32, 0, AnalysisConfig.Precision.Float32, False, False) + self.fetch_list = [out] + + def test_check_output(self): + if core.is_compiled_with_cuda(): + use_gpu = True + self.check_output_with_option(use_gpu) + self.assertTrue( + PassVersionChecker.IsCompatible('tensorrt_subgraph_pass')) + + +class TensorRTSubgraphPassInstanceNormTest(InferencePassTest): + def setUp(self): + with fluid.program_guard(self.main_program, self.startup_program): + data = fluid.data( + name="data", shape=[-1, 3, 64, 64], dtype="float32") + fc_out = fluid.layers.fc(input=data, size=200) + param_attr = fluid.ParamAttr( + name='instance_norm_w', + initializer=fluid.initializer.Constant(value=1.0)) + bias_attr = fluid.ParamAttr( + name='instance_norm_b', + initializer=fluid.initializer.Constant(value=0.0)) + out = fluid.layers.instance_norm( + input=fc_out, param_attr=param_attr, bias_attr=bias_attr) + self.feeds = { + "data": np.random.random([1, 3, 64, 64]).astype("float32"), + } + self.enable_trt = True + self.trt_parameters = TensorRTSubgraphPassInstanceNormTest.TensorRTParam( + 1 << 30, 32, 0, AnalysisConfig.Precision.Float32, False, False) + self.fetch_list = [out] + + def test_check_output(self): + if core.is_compiled_with_cuda(): + use_gpu = True + self.check_output_with_option(use_gpu, atol=1e-4, flatten=True) + self.assertTrue( + PassVersionChecker.IsCompatible('tensorrt_subgraph_pass')) + + +class TensorRTSubgraphPassLayerNormTest(InferencePassTest): + def setUp(self): + self.set_params() + with fluid.program_guard(self.main_program, self.startup_program): + data = fluid.data( + name="data", shape=[-1, 3, 64, 64], dtype="float32") + out = fluid.layers.layer_norm( + data, begin_norm_axis=self.begin_norm_axis) + self.feeds = { + "data": np.random.random([1, 3, 64, 64]).astype("float32"), + } + self.enable_trt = True + self.trt_parameters = TensorRTSubgraphPassLayerNormTest.TensorRTParam( + 1 << 30, 32, 0, AnalysisConfig.Precision.Float32, False, False) + self.fetch_list = [out] + + def set_params(self): + self.begin_norm_axis = 1 + + def test_check_output(self): + if core.is_compiled_with_cuda(): + use_gpu = True + self.check_output_with_option(use_gpu) + self.assertTrue( + PassVersionChecker.IsCompatible('tensorrt_subgraph_pass')) + + +class TensorRTSubgraphPassLayerNormBeginNormAxis2Test( + TensorRTSubgraphPassLayerNormTest): + def set_params(self): + self.begin_norm_axis = 2 + + +class TensorRTSubgraphPassLayerNormBeginNormAxis3Test( + TensorRTSubgraphPassLayerNormTest): + def set_params(self): + self.begin_norm_axis = 3 + + +class TensorRTSubgraphPassElementwiseTest(InferencePassTest): + def setUp(self): + with fluid.program_guard(self.main_program, self.startup_program): + data1 = fluid.data( + name="data1", shape=[-1, 3, 64, 64], dtype="float32") + data2 = fluid.data( + name="data2", shape=[-1, 3, 64, 64], dtype="float32") + eltwise_out = self.append_eltwise(data1, data2) + out = fluid.layers.batch_norm(eltwise_out, is_test=True) + self.feeds = { + "data1": np.random.random([1, 3, 64, 64]).astype("float32"), + "data2": np.random.random([1, 3, 64, 64]).astype("float32"), + } + self.enable_trt = True + self.trt_parameters = TensorRTSubgraphPassElementwiseTest.TensorRTParam( + 1 << 30, 32, 0, AnalysisConfig.Precision.Float32, False, False) + self.fetch_list = [out] + + def append_eltwise(self, data1, data2): + return fluid.layers.elementwise_add(x=data1, y=data2) + + def test_check_output(self): + if core.is_compiled_with_cuda(): + use_gpu = True + self.check_output_with_option(use_gpu) + self.assertTrue( + PassVersionChecker.IsCompatible('tensorrt_subgraph_pass')) + + +class TensorRTSubgraphPassElementwiseMulTest( + TensorRTSubgraphPassElementwiseTest): + def append_eltwise(self, data1, data2): + return fluid.layers.elementwise_mul(x=data1, y=data2) + + +class TensorRTSubgraphPassShuffleChannelTest(InferencePassTest): + def setUp(self): + with fluid.program_guard(self.main_program, self.startup_program): + data = fluid.data( + name="data", shape=[-1, 6, 64, 64], dtype="float32") + sc_out = fluid.layers.shuffle_channel(data, group=3) + out = fluid.layers.batch_norm(sc_out, is_test=True) + self.feeds = { + "data": np.random.random([1, 6, 64, 64]).astype("float32"), + } + self.enable_trt = True + self.trt_parameters = TensorRTSubgraphPassShuffleChannelTest.TensorRTParam( + 1 << 30, 32, 0, AnalysisConfig.Precision.Float32, False, False) + self.fetch_list = [out] + + def test_check_output(self): + if core.is_compiled_with_cuda(): + use_gpu = True + self.check_output_with_option(use_gpu) + self.assertTrue( + PassVersionChecker.IsCompatible('tensorrt_subgraph_pass')) + + +if __name__ == "__main__": + unittest.main() -- GitLab