Unverified commit bfb5cf55, authored by zlsh80826, committed by GitHub

[Paddle-TRT] trt affine channel converter (#31628)

* trt affine channel converter

* add trt affine channel base test

* add trt affine channel NHWC

* remove asterisk for python2 compatibility

* fix rebase

* move LoDTensor to Tensor

* add dbg info

* affine channel converter only supports NCHW

* scale, bias are parameters; use the create_parameter API

* reduce test input size so it does not exceed the CI time limit

* refine affine channel unittest and add serialization/dynamic test

* change super to InferencePassTest for python2 compatibility

* fix affine channel fp16 serialize setting
Parent b47478ef
@@ -1192,6 +1192,7 @@ USE_TRT_CONVERTER(scale);
USE_TRT_CONVERTER(stack);
USE_TRT_CONVERTER(clip);
USE_TRT_CONVERTER(gather);
USE_TRT_CONVERTER(affine_channel);
USE_TRT_CONVERTER(multiclass_nms);
USE_TRT_CONVERTER(nearest_interp);
#endif
@@ -6,6 +6,7 @@ nv_library(tensorrt_converter
shuffle_channel_op.cc swish_op.cc instance_norm_op.cc stack_op.cc transpose_op.cc flatten_op.cc
emb_eltwise_layernorm.cc skip_layernorm.cc scale_op.cc slice_op.cc hard_sigmoid_op.cc hard_swish_op.cc clip_op.cc
gather_op.cc
affine_channel_op.cc
multiclass_nms_op.cc
nearest_interp_op.cc
DEPS tensorrt_engine tensorrt_plugin operator scope framework_proto op_registry)
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"

namespace paddle {
namespace framework {
class Scope;
namespace proto {
class OpDesc;
} // namespace proto
} // namespace framework
}  // namespace paddle

namespace paddle {
namespace inference {
namespace tensorrt {

/*
 * Affine Channel Op
 */
class AffineChannelOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope, bool test_mode) override {
VLOG(3) << "convert a fluid affine_channel op to tensorrt scale nd layer";
framework::OpDesc op_desc(op, nullptr);
std::string input_name = op_desc.Input("X").front();
std::string scale_name = op_desc.Input("Scale").front();
std::string bias_name = op_desc.Input("Bias").front();
std::string output_name = op_desc.Output("Out").front();
auto input_tensor = engine_->GetITensor(input_name);
auto idim = input_tensor->getDimensions();
    // Fetch the scale and bias tensors from the scope and get CPU copies of
    // their data so they can be wrapped as TensorRT weights.
    auto* scale_v = scope.FindVar(scale_name);
auto* scale_t = scale_v->GetMutable<framework::LoDTensor>();
float* scale_ptr = engine_->GetWeightCPUData(scale_name, scale_t, false);
auto* bias_v = scope.FindVar(bias_name);
auto* bias_t = bias_v->GetMutable<framework::LoDTensor>();
float* bias_ptr = engine_->GetWeightCPUData(bias_name, bias_t, false);
auto data_layout = framework::StringToDataLayout(
BOOST_GET_CONST(std::string, op_desc.GetAttr("data_layout")));
PADDLE_ENFORCE_EQ(
data_layout, framework::DataLayout::kNCHW,
platform::errors::InvalidArgument(
"TensorRT affine channel converter can only convert NCHW format. "
"Other format should be run in fluid mode. Report a bug on github "
"issue if you see this line."));
    // TensorRT's ScaleNd layer only supports inputs with at least 2 spatial
    // dims, so NHWC is not available (the channel-last layout leaves 0
    // spatial dims after the channel axis).
    // With dynamic shape the batch dim is explicit, so the channel axis is 1;
    // with an implicit batch dim it is 0.
    const int channel_axis = engine_->with_dynamic_shape();
TensorRTEngine::Weight scale_weights{nvinfer1::DataType::kFLOAT,
static_cast<void*>(scale_ptr),
(size_t)idim.d[channel_axis]};
TensorRTEngine::Weight bias_weights{nvinfer1::DataType::kFLOAT,
static_cast<void*>(bias_ptr),
(size_t)idim.d[channel_axis]};
    // ScaleNd computes (x * scale + bias) ^ power per channel; the power term
    // is unused here, so its weights are left empty.
    TensorRTEngine::Weight power_weights{nvinfer1::DataType::kFLOAT, nullptr,
                                         0};
auto layer = TRT_ENGINE_ADD_LAYER(engine_, ScaleNd, *input_tensor,
nvinfer1::ScaleMode::kCHANNEL,
bias_weights.get(), scale_weights.get(),
power_weights.get(), channel_axis);
RreplenishLayerAndOutput(layer, "affine_channel", {output_name}, test_mode);
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(affine_channel, AffineChannelOpConverter);
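For reference, affine_channel applies a per-channel affine transform: out[n,c,h,w] = x[n,c,h,w] * scale[c] + bias[c], which is what the ScaleNd layer in kCHANNEL mode reproduces above. A minimal NumPy sketch of the expected NCHW semantics (the function name and sample data are illustrative, not part of the patch):

```python
import numpy as np

def affine_channel_nchw(x, scale, bias):
    # out[n, c, h, w] = x[n, c, h, w] * scale[c] + bias[c]
    # Reshape scale and bias so they broadcast along the channel axis (axis 1).
    return x * scale.reshape(1, -1, 1, 1) + bias.reshape(1, -1, 1, 1)

x = np.random.rand(2, 8, 16, 16).astype('float32')  # NCHW, as in the unit test
scale = np.full(8, 2.0, dtype='float32')             # per-channel scale
bias = np.full(8, 0.5, dtype='float32')              # per-channel bias
out = affine_channel_nchw(x, scale, bias)
assert out.shape == x.shape
```

The shapes mirror the unit test below (batch 2, 8 channels, 16x16 spatial).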
@@ -111,6 +111,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"flatten2",
"flatten",
"gather",
"affine_channel",
"multiclass_nms",
"nearest_interp",
};
@@ -196,6 +197,13 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
if (!with_dynamic_shape || desc.Input("Axis").size() > 0) return false;
}
if (op_type == "affine_channel") {
if (!desc.HasAttr("data_layout")) return false;
auto data_layout = framework::StringToDataLayout(
BOOST_GET_CONST(std::string, desc.GetAttr("data_layout")));
if (data_layout != framework::DataLayout::kNCHW) return false;
}
if (op_type == "multiclass_nms") {
if (with_dynamic_shape) return false;
auto* block = desc.Block();
@@ -238,6 +246,7 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
return false;
}
}
if (op_type == "nearest_interp") {
std::vector<std::string> attrs{"data_layout", "interp_method",
"align_corners", "scale",
@@ -254,7 +263,6 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
BOOST_GET_CONST(std::string, desc.GetAttr("interp_method"));
if (interp_method != "nearest") return false;
}
if ((*teller)(op_type, desc, use_no_calib_int8)) return true;
}
return false;
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import itertools
import numpy as np
from inference_pass_test import InferencePassTest
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.core import PassVersionChecker
from paddle.fluid.core import AnalysisConfig


class TRTAffineChannelTest(InferencePassTest):
def setUp(self):
self.bs = 2
self.channel = 8
self.height = 16
self.width = 16
self.data_layout = 'NCHW'
self.precision = AnalysisConfig.Precision.Float32
self.serialize = False
self.enable_trt = True

    def build(self):
        # Set min_subgraph_size to 2 because the affine_channel converter
        # doesn't support the NHWC format.
        self.trt_parameters = InferencePassTest.TensorRTParam(
            1 << 30, self.bs, 2, self.precision, self.serialize, False)
with fluid.program_guard(self.main_program, self.startup_program):
if self.data_layout == 'NCHW':
shape = [-1, self.channel, self.height, self.width]
else:
shape = [-1, self.height, self.width, self.channel]
data = fluid.data(name='in', shape=shape, dtype='float32')
            # Initialize scale and bias with constant values.
            scale = fluid.layers.create_parameter(
shape=[self.channel],
dtype='float32',
default_initializer=fluid.initializer.Constant(2.))
bias = fluid.layers.create_parameter(
shape=[self.channel],
dtype='float32',
default_initializer=fluid.initializer.Constant(.5))
affine_channel_out = fluid.layers.affine_channel(
data, scale=scale, bias=bias, data_layout=self.data_layout)
out = fluid.layers.batch_norm(affine_channel_out, is_test=True)
shape[0] = self.bs
self.feeds = {'in': np.random.random(shape).astype('float32'), }
self.fetch_list = [out]

    def check_output(self):
if core.is_compiled_with_cuda():
use_gpu = True
atol = 1e-5
if self.trt_parameters.precision == AnalysisConfig.Precision.Half:
atol = 1e-3
self.check_output_with_option(use_gpu, atol, flatten=True)
self.assertTrue(
PassVersionChecker.IsCompatible('tensorrt_subgraph_pass'))

    def run_test(self):
self.build()
self.check_output()

    def run_test_all(self):
precision_opt = [
AnalysisConfig.Precision.Float32, AnalysisConfig.Precision.Half
]
serialize_opt = [False, True]
if self.data_layout == 'NCHW':
min_shape = [
self.bs, self.channel, self.height // 2, self.width // 2
]
max_shape = [self.bs, self.channel, self.height * 2, self.width * 2]
opt_shape = [self.bs, self.channel, self.height, self.width]
if self.data_layout == 'NHWC':
min_shape = [
self.bs, self.height // 2, self.width // 2, self.channel
]
max_shape = [self.bs, self.height * 2, self.width * 2, self.channel]
opt_shape = [self.bs, self.height, self.width, self.channel]
dynamic_shape_profile = InferencePassTest.DynamicShapeParam({
'in': min_shape
}, {'in': max_shape}, {'in': opt_shape}, False)
dynamic_shape_opt = [None, dynamic_shape_profile]
for precision, serialize, dynamic_shape in itertools.product(
precision_opt, serialize_opt, dynamic_shape_opt):
self.precision = precision
self.serialize = serialize
self.dynamic_shape_params = dynamic_shape
self.run_test()

    def test_base(self):
self.run_test()

    def test_fp16(self):
self.precision = AnalysisConfig.Precision.Half
self.run_test()

    def test_serialize(self):
self.serialize = True
self.run_test()

    def test_dynamic(self):
self.dynamic_shape_params = InferencePassTest.DynamicShapeParam({
'in': [self.bs, self.channel, self.height // 2, self.width // 2]
}, {'in': [self.bs, self.channel, self.height * 2, self.width * 2]
}, {'in': [self.bs, self.channel, self.height, self.width]}, False)
self.run_test()

    def test_nchw_all(self):
self.run_test_all()

    def test_nhwc(self):
self.data_layout = 'NHWC'
self.run_test_all()


if __name__ == "__main__":
unittest.main()
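For context, run_test_all sweeps two precisions, two serialization flags, and two shape modes with itertools.product, i.e. eight configurations per data layout. A standalone sketch of that sweep (the shape values mirror setUp; the profile dict layout here is an assumption of this note, not the actual DynamicShapeParam API):

```python
import itertools

bs, channel, height, width = 2, 8, 16, 16  # mirrors setUp

# Dynamic-shape profile: min / max / opt shapes for the 'in' tensor (NCHW).
min_shape = [bs, channel, height // 2, width // 2]
max_shape = [bs, channel, height * 2, width * 2]
opt_shape = [bs, channel, height, width]

precision_opt = ['Float32', 'Half']
serialize_opt = [False, True]
dynamic_shape_opt = [None, {'in': (min_shape, max_shape, opt_shape)}]

# itertools.product enumerates all 2 * 2 * 2 = 8 combinations per layout.
for precision, serialize, dynamic_shape in itertools.product(
        precision_opt, serialize_opt, dynamic_shape_opt):
    print(precision, serialize, dynamic_shape is not None)
```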