Commit 63ade29b authored by cryoco

add unittests and op version register for tensorrt_subgraph_pass

Parent c7e5cf16
@@ -18,6 +18,7 @@
 #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
 #include "paddle/fluid/framework/ir/subgraph_detector.h"
+#include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/inference/analysis/helper.h"
 #include "paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h"
 #include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
@@ -358,3 +359,31 @@ REGISTER_PASS(tensorrt_subgraph_pass,
     .RequirePassAttr("max_batch_size")
     .RequirePassAttr("workspace_size")
     .RequirePassAttr("min_subgraph_size");
+REGISTER_PASS_CAPABILITY(tensorrt_subgraph_pass)
+    .AddCombination(
+        paddle::framework::compatible::OpVersionComparatorCombination()
+            .EQ("conv2d", 0)
+            .EQ("pool2d", 0)
+            .EQ("relu", 0)
+            .EQ("softmax", 0)
+            .EQ("sigmoid", 0)
+            .EQ("hard_swish", 0)
+            .EQ("depthwise_conv2d", 0)
+            .EQ("batch_norm", 0)
+            .EQ("concat", 0)
+            .EQ("tanh", 0)
+            .EQ("pad", 0)
+            .EQ("elementwise_add", 0)
+            .EQ("elementwise_mul", 0)
+            .EQ("prelu", 0)
+            .LE("conv2d_transpose", 1)
+            .LE("leaky_relu", 1)
+            .EQ("fc", 0)
+            .EQ("shuffle_channel", 0)
+            .EQ("swish", 0)
+            .EQ("split", 0)
+            .EQ("instance_norm", 0)
+            .EQ("gelu", 0)
+            .EQ("layer_norm", 0)
+            .EQ("scale", 0));
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
+#include <algorithm>
 #include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
 namespace paddle {
@@ -30,9 +31,28 @@ class SoftMaxOpConverter : public OpConverter {
     framework::OpDesc op_desc(op, nullptr);
     // Declare inputs
     auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]);
+    nvinfer1::Dims input_shape = input1->getDimensions();
+    int input_dims = input_shape.nbDims;
+    int axis = BOOST_GET_CONST(int, op_desc.GetAttr("axis"));
     auto* layer = TRT_ENGINE_ADD_LAYER(engine_, SoftMax,
                                        *const_cast<nvinfer1::ITensor*>(input1));
+    uint32_t axes = std::max(0, input_dims - 3);
+    if (!engine_->with_dynamic_shape()) {
+      if (axis == -1) {
+        axes = input_dims - 1;
+      } else {
+        axes = axis;
+      }
+      layer->setAxes(1 << axes);
+    } else {
+      if (axis == -1) {
+        axes = input_dims - 1;
+      } else {
+        axes = axis + 1;
+      }
+      layer->setAxes(1 << axes);
+    }
     auto output_name = op_desc.Output("Out")[0];
     RreplenishLayerAndOutput(layer, "softmax", {output_name}, test_mode);
......
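The added branch maps Paddle's softmax `axis` attribute onto the one-hot bitmask that ISoftMaxLayer::setAxes expects. A Python sketch of the same mapping (softmax_axes_bitmask is a hypothetical name, not a Paddle helper):

def softmax_axes_bitmask(axis, trt_nb_dims, with_dynamic_shape):
    # axis == -1 always selects the last TRT dimension; otherwise the
    # attribute is used as-is in static-shape mode and shifted by one in
    # dynamic-shape mode, where the batch dim is part of the TRT tensor.
    if axis == -1:
        axes = trt_nb_dims - 1
    elif with_dynamic_shape:
        axes = axis + 1
    else:
        axes = axis
    return 1 << axes  # value passed to ISoftMaxLayer::setAxes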
@@ -107,7 +107,11 @@ bool OpTeller::Tell(const std::string& op_type, const framework::OpDesc& desc,
       op_type == "depthwise_conv2d" || op_type == "conv2d_transpose") {
     std::vector<int> paddings =
         BOOST_GET_CONST(std::vector<int>, desc.GetAttr("paddings"));
-    if (paddings.size() > 2) return false;
+    std::string padding_algorithm =
+        BOOST_GET_CONST(std::string, desc.GetAttr("padding_algorithm"));
+    if (paddings.size() > 2 ||
+        (padding_algorithm == "SAME" && op_type != "pool2d"))
+      return false;
   }
   if ((*teller)(op_type, desc, use_no_calib_int8)) return true;
 }
......
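The extended condition reads as a small predicate: conv-like ops are only offered to TensorRT when their padding is something TRT can reproduce. A sketch (conv_padding_supported is a hypothetical name):

def conv_padding_supported(op_type, paddings, padding_algorithm):
    # Asymmetric paddings with more than two entries are rejected.
    if len(paddings) > 2:
        return False
    # 'SAME' padding is only accepted for pool2d; conv2d, depthwise_conv2d
    # and conv2d_transpose with 'SAME' stay on the native GPU path.
    if padding_algorithm == "SAME" and op_type != "pool2d":
        return False
    return True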
@@ -50,10 +50,18 @@ int LayerNormPlugin::enqueue(int batch_size, const void *const *inputs,
   float *output = reinterpret_cast<float **>(outputs)[0];
   int begin_norm_axis = begin_norm_axis_;
   float eps = eps_;
-  int c = input_dims.d[begin_norm_axis - 1];
-  scale_t.Resize(framework::make_ddim({c}));
-  bias_t.Resize(framework::make_ddim({c}));
+  std::vector<int> input_shape;
+  input_shape.push_back(batch_size);
+  for (int i = 0; i < input_dims.nbDims; i++) {
+    input_shape.push_back(input_dims.d[i]);
+  }
+  const auto input_ddim = framework::make_ddim(input_shape);
+  auto matrix_dim = framework::flatten_to_2d(input_ddim, begin_norm_axis - 1);
+  int feature_size = static_cast<int>(matrix_dim[1]);
+  scale_t.Resize(framework::make_ddim({feature_size}));
+  bias_t.Resize(framework::make_ddim({feature_size}));
   mean_t.Resize(framework::make_ddim(mean_shape_));
   variance_t.Resize(framework::make_ddim(variance_shape_));
   int device_id;
@@ -63,15 +71,11 @@ int LayerNormPlugin::enqueue(int batch_size, const void *const *inputs,
   float *mean_d = mean_t.mutable_data<float>(platform::CUDAPlace(device_id));
   float *variance_d =
       variance_t.mutable_data<float>(platform::CUDAPlace(device_id));
-  cudaMemcpyAsync(scale_d, scale_.data(), sizeof(float) * c,
+  cudaMemcpyAsync(scale_d, scale_.data(), sizeof(float) * feature_size,
                   cudaMemcpyHostToDevice, stream);
-  cudaMemcpyAsync(bias_d, bias_.data(), sizeof(float) * c,
+  cudaMemcpyAsync(bias_d, bias_.data(), sizeof(float) * feature_size,
                   cudaMemcpyHostToDevice, stream);
-  std::vector<int> input_shape;
-  input_shape.push_back(batch_size);
-  for (int i = 0; i < input_dims.nbDims; i++) {
-    input_shape.push_back(input_dims.d[i]);
-  }
   paddle::operators::LayerNormDirectCUDAFunctor<float> layer_norm;
   layer_norm(stream, input, input_shape, bias_d, scale_d, output, mean_d,
              variance_d, begin_norm_axis, eps);
......
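Previously the scale/bias buffers were sized from a single dimension, input_dims.d[begin_norm_axis - 1], which is too small whenever layer_norm normalizes more than one trailing dimension. The new code flattens the batch-included shape into a 2-D matrix and takes the second extent. A numpy sketch of that computation, assuming flatten_to_2d(shape, k) splits the shape at index k:

import numpy as np

def layer_norm_feature_size(batch_size, trt_dims, begin_norm_axis):
    full_shape = [batch_size] + list(trt_dims)
    # matrix_dim == (prod(full_shape[:k]), prod(full_shape[k:])) with
    # k = begin_norm_axis - 1, matching the flatten_to_2d call above.
    return int(np.prod(full_shape[begin_norm_axis - 1:]))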
@@ -133,7 +133,7 @@ class InferencePassTest(unittest.TestCase):
         for place_ in use_gpu:
             self.check_output_with_option(place_, atol)
-    def check_output_with_option(self, use_gpu, atol=1e-5):
+    def check_output_with_option(self, use_gpu, atol=1e-5, flatten=False):
         '''
         Check whether calculating on CPU and GPU, enable TensorRT
         or disable TensorRT, enable MKLDNN or disable MKLDNN
@@ -155,9 +155,13 @@ class InferencePassTest(unittest.TestCase):
                 format(device))
         for out, analysis_output in zip(outs, analysis_outputs):
+            out = np.array(out)
+            if flatten:
+                out = out.flatten()
+                analysis_output = analysis_output.flatten()
             self.assertTrue(
                 np.allclose(
-                    np.array(out), analysis_output, atol=atol),
+                    out, analysis_output, atol=atol),
                 "Output has diff between inference and training forward at {} ".
                 format(device))
@@ -172,9 +176,13 @@ class InferencePassTest(unittest.TestCase):
             "The number of outputs is different between GPU and TensorRT. ")
         for out, tensorrt_output in zip(outs, tensorrt_outputs):
+            out = np.array(out)
+            if flatten:
+                out = out.flatten()
+                tensorrt_output = tensorrt_output.flatten()
             self.assertTrue(
                 np.allclose(
-                    np.array(out), tensorrt_output, atol=atol),
+                    out, tensorrt_output, atol=atol),
                 "Output has diff between GPU and TensorRT. ")
         # Check whether the mkldnn results and the CPU results are the same.
......
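The new flatten switch exists because some TensorRT layers report extra trailing 1-dims; the fc test below, for instance, gets a (1, 1000, 1, 1) tensor back for a (1, 1000) result. Flattening both arrays turns the assertion into a value-only comparison:

import numpy as np

out = np.random.random([1, 1000]).astype("float32")
trt_out = out.reshape([1, 1000, 1, 1])  # same values, TRT-style shape

# The raw arrays no longer pair up element-for-element under numpy
# broadcasting, so compare the flattened views instead.
assert np.allclose(out.flatten(), trt_out.flatten(), atol=1e-5)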
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from inference_pass_test import InferencePassTest
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.core import PassVersionChecker
from paddle.fluid.core import AnalysisConfig
class TensorRTSubgraphPassConvTest(InferencePassTest):
def setUp(self):
self.set_params()
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name="data", shape=[-1, 6, 64, 64], dtype="float32")
conv_out = fluid.layers.conv2d(
input=data,
num_filters=self.conv_num_filters,
filter_size=self.conv_filter_size,
groups=self.conv_groups,
padding=self.conv_padding,
bias_attr=False,
act=None)
self.feeds = {
"data": np.random.random([1, 6, 64, 64]).astype("float32"),
}
self.enable_trt = True
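        # TensorRTParam fields (assumed to mirror enable_tensorrt_engine):
        # workspace_size, max_batch_size, min_subgraph_size, precision,
        # use_static, use_calib_mode.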
self.trt_parameters = TensorRTSubgraphPassConvTest.TensorRTParam(
1 << 30, 32, 0, AnalysisConfig.Precision.Float32, False, False)
self.fetch_list = [conv_out]
def set_params(self):
self.conv_num_filters = 6
self.conv_filter_size = 6
self.conv_groups = 3
self.conv_padding = [1, 1]
def test_check_output(self):
if core.is_compiled_with_cuda():
use_gpu = True
self.check_output_with_option(use_gpu)
self.assertTrue(
PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))
class TensorRTSubgraphPassConvValidPaddingTest(TensorRTSubgraphPassConvTest):
def set_params(self):
self.conv_num_filters = 6
self.conv_filter_size = 6
self.conv_groups = 3
self.conv_padding = 'VALID'
'''
# conv2d with 'SAME' padding is not yet supported by TRT; re-enable this test once support is complete.
class TensorRTSubgraphPassConvSamePaddingTest(InferencePassTest):
def set_params(self):
self.conv_num_filters = 6
self.conv_filter_size = 6
self.conv_groups = 3
self.conv_padding = 'SAME'
'''
class TensorRTSubgraphPassDepthwiseConvTest(TensorRTSubgraphPassConvTest):
def set_params(self):
self.conv_num_filters = 6
self.conv_filter_size = 6
self.conv_groups = 6
self.conv_padding = [1, 1]
class TensorRTSubgraphPassConvTransposeTest(InferencePassTest):
def setUp(self):
self.set_params()
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name="data", shape=[-1, 6, 64, 64], dtype="float32")
conv_out = fluid.layers.conv2d_transpose(
input=data,
num_filters=self.conv_num_filters,
filter_size=self.conv_filter_size,
groups=self.conv_groups,
padding=self.conv_padding,
bias_attr=False,
act=None)
self.feeds = {
"data": np.random.random([1, 6, 64, 64]).astype("float32"),
}
self.enable_trt = True
self.trt_parameters = TensorRTSubgraphPassConvTransposeTest.TensorRTParam(
1 << 30, 32, 0, AnalysisConfig.Precision.Float32, False, False)
self.fetch_list = [conv_out]
def set_params(self):
self.conv_num_filters = 6
self.conv_filter_size = 6
self.conv_groups = 1
self.conv_padding = [1, 1]
def test_check_output(self):
if core.is_compiled_with_cuda():
use_gpu = True
self.check_output_with_option(use_gpu)
self.assertTrue(
PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))
class TensorRTSubgraphPassConvTransposeValidPaddingTest(
TensorRTSubgraphPassConvTransposeTest):
def set_params(self):
self.conv_num_filters = 6
self.conv_filter_size = 6
self.conv_groups = 1
self.conv_padding = 'VALID'
'''
# conv2d_transpose with 'SAME' padding is not yet supported by TRT; re-enable this test once support is complete.
class TensorRTSubgraphPassConvTransposeSamePaddingTest(TensorRTSubgraphPassConvTransposeTest):
def set_params(self):
self.conv_num_filters = 6
self.conv_filter_size = 6
self.conv_groups = 1
self.conv_padding = 'SAME'
'''
class TensorRTSubgraphPassDepthwiseConvTransposeTest(
TensorRTSubgraphPassConvTransposeTest):
def set_params(self):
self.conv_num_filters = 6
self.conv_filter_size = 6
self.conv_groups = 1
self.conv_padding = [1, 1]
class TensorRTSubgraphPassFcTest(InferencePassTest):
def setUp(self):
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name="data", shape=[-1, 6, 64, 64], dtype="float32")
fc_out = fluid.layers.fc(input=[data], act=None, size=1000)
reshape_out = fluid.layers.reshape(x=fc_out, shape=[1, 1000])
self.feeds = {
"data": np.random.random([1, 6, 64, 64]).astype("float32"),
}
self.enable_trt = True
self.trt_parameters = TensorRTSubgraphPassFcTest.TensorRTParam(
1 << 30, 32, 0, AnalysisConfig.Precision.Float32, False, False)
self.fetch_list = [reshape_out]
def test_check_output(self):
if core.is_compiled_with_cuda():
use_gpu = True
            # The TRT output shape of fc is (1, 1000, 1, 1). To compare output values only, flatten the results.
self.check_output_with_option(use_gpu, flatten=True)
self.assertTrue(
PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))
class TensorRTSubgraphPassPoolTest(InferencePassTest):
def setUp(self):
self.set_params()
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name="data", shape=[-1, 6, 64, 64], dtype="float32")
pool_out = fluid.layers.pool2d(
input=data,
pool_size=self.pool_size,
pool_type=self.pool_type,
pool_stride=self.pool_stride,
pool_padding=self.pool_padding,
global_pooling=self.global_pooling,
ceil_mode=self.ceil_mode,
exclusive=self.exclusive)
out = fluid.layers.batch_norm(pool_out, is_test=True)
self.feeds = {
"data": np.random.random([1, 6, 64, 64]).astype("float32"),
}
self.enable_trt = True
self.trt_parameters = TensorRTSubgraphPassPoolTest.TensorRTParam(
1 << 30, 32, 0, AnalysisConfig.Precision.Float32, False, False)
self.fetch_list = [out]
def set_params(self):
self.pool_size = 2
self.pool_type = 'max'
self.pool_stride = 1
self.pool_padding = 0
self.global_pooling = False
self.ceil_mode = False
self.exclusive = False
def test_check_output(self):
if core.is_compiled_with_cuda():
use_gpu = True
self.check_output_with_option(use_gpu)
self.assertTrue(
PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))
class TensorRTSubgraphPassAvgPoolTest(TensorRTSubgraphPassPoolTest):
def set_params(self):
self.pool_size = 2
self.pool_type = 'avg'
self.pool_stride = 1
self.pool_padding = 0
self.global_pooling = False
self.ceil_mode = False
self.exclusive = False
class TensorRTSubgraphPassGlobalPoolTest(TensorRTSubgraphPassPoolTest):
def set_params(self):
self.pool_size = 2
self.pool_type = 'max'
self.pool_stride = 1
self.pool_padding = 0
self.global_pooling = True
self.ceil_mode = False
self.exclusive = False
class TensorRTSubgraphPassCeilPoolTest(TensorRTSubgraphPassPoolTest):
def set_params(self):
self.pool_size = 2
self.pool_type = 'max'
self.pool_stride = 1
self.pool_padding = 0
self.global_pooling = False
self.ceil_mode = True
self.exclusive = False
class TensorRTSubgraphPassExclusivePoolTest(TensorRTSubgraphPassPoolTest):
def set_params(self):
self.pool_size = 2
self.pool_type = 'max'
self.pool_stride = 1
self.pool_padding = 0
self.global_pooling = False
self.ceil_mode = False
self.exclusive = True
class TensorRTSubgraphPassSamePaddingPoolTest(InferencePassTest):
def set_params(self):
self.pool_size = 2
self.pool_type = 'max'
self.pool_stride = 1
self.pool_padding = 'SAME'
self.global_pooling = False
self.ceil_mode = False
self.exclusive = False
class TensorRTSubgraphPassValidPaddingPoolTest(InferencePassTest):
def set_params(self):
self.pool_size = 2
self.pool_type = 'max'
self.pool_stride = 1
self.pool_padding = 'VALID'
self.global_pooling = False
self.ceil_mode = False
self.exclusive = False
class TensorRTSubgraphPassActivationTest(InferencePassTest):
def setUp(self):
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name="data", shape=[-1, 6, 64, 64], dtype="float32")
act_out = self.append_act(data)
out = fluid.layers.batch_norm(act_out, is_test=True)
self.feeds = {
"data": np.random.random([1, 6, 64, 64]).astype("float32"),
}
self.enable_trt = True
self.trt_parameters = TensorRTSubgraphPassActivationTest.TensorRTParam(
1 << 30, 32, 0, AnalysisConfig.Precision.Float32, False, False)
self.fetch_list = [out]
def append_act(self, x):
return fluid.layers.relu(x)
def test_check_output(self):
if core.is_compiled_with_cuda():
use_gpu = True
self.check_output_with_option(use_gpu)
self.assertTrue(
PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))
class TensorRTSubgraphPassLeakyReluTest(TensorRTSubgraphPassActivationTest):
def append_act(self, x):
return fluid.layers.leaky_relu(x)
class TensorRTSubgraphPassRelu6Test(TensorRTSubgraphPassActivationTest):
def append_act(self, x):
return fluid.layers.relu6(x)
class TensorRTSubgraphPassSoftMaxTest(TensorRTSubgraphPassActivationTest):
def append_act(self, x):
return fluid.layers.softmax(x)
class TensorRTSubgraphPassSigmoidTest(TensorRTSubgraphPassActivationTest):
def append_act(self, x):
return fluid.layers.sigmoid(x)
class TensorRTSubgraphPassHardSwishTest(TensorRTSubgraphPassActivationTest):
def append_act(self, x):
return fluid.layers.hard_swish(x)
class TensorRTSubgraphPassHardSigmoidTest(TensorRTSubgraphPassActivationTest):
def append_act(self, x):
return fluid.layers.hard_sigmoid(x)
class TensorRTSubgraphPassTanhTest(TensorRTSubgraphPassActivationTest):
def append_act(self, x):
return fluid.layers.tanh(x)
class TensorRTSubgraphPassSwishTest(TensorRTSubgraphPassActivationTest):
def append_act(self, x):
return fluid.layers.swish(x)
class TensorRTSubgraphPassPreluAllTest(TensorRTSubgraphPassActivationTest):
def append_act(self, x):
return fluid.layers.prelu(x, mode='all')
class TensorRTSubgraphPassPreluChannelTest(TensorRTSubgraphPassActivationTest):
def append_act(self, x):
return fluid.layers.prelu(x, mode='channel')
class TensorRTSubgraphPassPreluElementTest(TensorRTSubgraphPassActivationTest):
def append_act(self, x):
return fluid.layers.prelu(x, mode='element')
class TensorRTSubgraphPassGeluTest(TensorRTSubgraphPassActivationTest):
def append_act(self, x):
return fluid.layers.gelu(x)
class TensorRTSubgraphPassConcatTest(InferencePassTest):
def setUp(self):
with fluid.program_guard(self.main_program, self.startup_program):
data1 = fluid.data(
name="data1", shape=[-1, 3, 64, 64], dtype="float32")
data2 = fluid.data(
name="data2", shape=[-1, 3, 64, 64], dtype="float32")
concat_out = fluid.layers.concat([data1, data2], axis=2)
out = fluid.layers.batch_norm(concat_out, is_test=True)
self.feeds = {
"data1": np.random.random([1, 3, 64, 64]).astype("float32"),
"data2": np.random.random([1, 3, 64, 64]).astype("float32"),
}
self.enable_trt = True
self.trt_parameters = TensorRTSubgraphPassConcatTest.TensorRTParam(
1 << 30, 32, 0, AnalysisConfig.Precision.Float32, False, False)
self.fetch_list = [out]
def test_check_output(self):
if core.is_compiled_with_cuda():
use_gpu = True
self.check_output_with_option(use_gpu)
self.assertTrue(
PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))
class TensorRTSubgraphPassSplitTest(InferencePassTest):
def setUp(self):
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name="data", shape=[-1, 3, 64, 64], dtype="float32")
split_out = fluid.layers.split(data, dim=-1, num_or_sections=2)
out = fluid.layers.batch_norm(split_out[0], is_test=True)
self.feeds = {
"data": np.random.random([1, 3, 64, 64]).astype("float32"),
}
self.enable_trt = True
self.trt_parameters = TensorRTSubgraphPassSplitTest.TensorRTParam(
1 << 30, 32, 0, AnalysisConfig.Precision.Float32, False, False)
self.fetch_list = [out]
def test_check_output(self):
if core.is_compiled_with_cuda():
use_gpu = True
self.check_output_with_option(use_gpu)
self.assertTrue(
PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))
class TensorRTSubgraphPassInstanceNormTest(InferencePassTest):
def setUp(self):
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name="data", shape=[-1, 3, 64, 64], dtype="float32")
fc_out = fluid.layers.fc(input=data, size=200)
param_attr = fluid.ParamAttr(
name='instance_norm_w',
initializer=fluid.initializer.Constant(value=1.0))
bias_attr = fluid.ParamAttr(
name='instance_norm_b',
initializer=fluid.initializer.Constant(value=0.0))
out = fluid.layers.instance_norm(
input=fc_out, param_attr=param_attr, bias_attr=bias_attr)
self.feeds = {
"data": np.random.random([1, 3, 64, 64]).astype("float32"),
}
self.enable_trt = True
self.trt_parameters = TensorRTSubgraphPassInstanceNormTest.TensorRTParam(
1 << 30, 32, 0, AnalysisConfig.Precision.Float32, False, False)
self.fetch_list = [out]
def test_check_output(self):
if core.is_compiled_with_cuda():
use_gpu = True
self.check_output_with_option(use_gpu, atol=1e-4, flatten=True)
self.assertTrue(
PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))
class TensorRTSubgraphPassLayerNormTest(InferencePassTest):
def setUp(self):
self.set_params()
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name="data", shape=[-1, 3, 64, 64], dtype="float32")
out = fluid.layers.layer_norm(
data, begin_norm_axis=self.begin_norm_axis)
self.feeds = {
"data": np.random.random([1, 3, 64, 64]).astype("float32"),
}
self.enable_trt = True
self.trt_parameters = TensorRTSubgraphPassLayerNormTest.TensorRTParam(
1 << 30, 32, 0, AnalysisConfig.Precision.Float32, False, False)
self.fetch_list = [out]
def set_params(self):
self.begin_norm_axis = 1
def test_check_output(self):
if core.is_compiled_with_cuda():
use_gpu = True
self.check_output_with_option(use_gpu)
self.assertTrue(
PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))
class TensorRTSubgraphPassLayerNormBeginNormAxis2Test(
TensorRTSubgraphPassLayerNormTest):
def set_params(self):
self.begin_norm_axis = 2
class TensorRTSubgraphPassLayerNormBeginNormAxis3Test(
TensorRTSubgraphPassLayerNormTest):
def set_params(self):
self.begin_norm_axis = 3
class TensorRTSubgraphPassElementwiseTest(InferencePassTest):
def setUp(self):
with fluid.program_guard(self.main_program, self.startup_program):
data1 = fluid.data(
name="data1", shape=[-1, 3, 64, 64], dtype="float32")
data2 = fluid.data(
name="data2", shape=[-1, 3, 64, 64], dtype="float32")
eltwise_out = self.append_eltwise(data1, data2)
out = fluid.layers.batch_norm(eltwise_out, is_test=True)
self.feeds = {
"data1": np.random.random([1, 3, 64, 64]).astype("float32"),
"data2": np.random.random([1, 3, 64, 64]).astype("float32"),
}
self.enable_trt = True
self.trt_parameters = TensorRTSubgraphPassElementwiseTest.TensorRTParam(
1 << 30, 32, 0, AnalysisConfig.Precision.Float32, False, False)
self.fetch_list = [out]
def append_eltwise(self, data1, data2):
return fluid.layers.elementwise_add(x=data1, y=data2)
def test_check_output(self):
if core.is_compiled_with_cuda():
use_gpu = True
self.check_output_with_option(use_gpu)
self.assertTrue(
PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))
class TensorRTSubgraphPassElementwiseMulTest(
TensorRTSubgraphPassElementwiseTest):
def append_eltwise(self, data1, data2):
return fluid.layers.elementwise_mul(x=data1, y=data2)
class TensorRTSubgraphPassShuffleChannelTest(InferencePassTest):
def setUp(self):
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name="data", shape=[-1, 6, 64, 64], dtype="float32")
sc_out = fluid.layers.shuffle_channel(data, group=3)
out = fluid.layers.batch_norm(sc_out, is_test=True)
self.feeds = {
"data": np.random.random([1, 6, 64, 64]).astype("float32"),
}
self.enable_trt = True
self.trt_parameters = TensorRTSubgraphPassShuffleChannelTest.TensorRTParam(
1 << 30, 32, 0, AnalysisConfig.Precision.Float32, False, False)
self.fetch_list = [out]
def test_check_output(self):
if core.is_compiled_with_cuda():
use_gpu = True
self.check_output_with_option(use_gpu)
self.assertTrue(
PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))
if __name__ == "__main__":
unittest.main()