未验证 提交 62d848de 编写于 作者: W Wangzheee 提交者: GitHub

[Paddle-TRT]fix trt-converter-fc_op (#32671)

* [Paddle-TRT]fix fc_op

* [Paddle-TRT]fix fc_op

* [Paddle-TRT]fix fc_op

* test_trt_subgraph_pass.py

* fix elementwise_op

* fix elementwise_op

* fix elementwise_op

* fix elementwise_op.cc

* op_teller.cc
上级 c1c18b08
...@@ -66,6 +66,25 @@ class ElementwiseWeightOpConverter : public OpConverter { ...@@ -66,6 +66,25 @@ class ElementwiseWeightOpConverter : public OpConverter {
0}; 0};
TensorRTEngine::Weight power_weights{nvinfer1::DataType::kFLOAT, nullptr, TensorRTEngine::Weight power_weights{nvinfer1::DataType::kFLOAT, nullptr,
0}; 0};
// Optional reshape layers placed around the scale layer. Each pointer stays
// nullptr when the input already has enough dimensions.
nvinfer1::IShuffleLayer* expand_layer = nullptr;
nvinfer1::IShuffleLayer* squeeze_layer = nullptr;
// Dynamic-shape mode raises the minimum rank by one (presumably because the
// batch dimension is part of the tensor shape there -- TODO confirm).
int dynamic_shape_offset = engine_->with_dynamic_shape() ? 1 : 0;
// Remember the original dimensions of X; X itself may be replaced below.
auto input_dim = X->getDimensions();
// Inputs with fewer than 3 (+1 in dynamic-shape mode) dimensions are padded
// with trailing 1s before being handed to the scale layer.
// NOTE(review): the minimum rank looks like a requirement of the TensorRT
// Scale layer added afterwards -- confirm against the IScaleLayer docs.
if (input_dim.nbDims < 3 + dynamic_shape_offset) {
nvinfer1::Dims expand_shape;
expand_shape.nbDims = 3 + dynamic_shape_offset;
for (int i = 0; i < expand_shape.nbDims; i++) {
if (i < input_dim.nbDims) {
// Negative (unknown) extents are written as 0. In IShuffleLayer reshape
// dimensions 0 is documented as "copy the matching input dimension"
// (verify against the TensorRT version in use).
expand_shape.d[i] = input_dim.d[i] < 0 ? 0 : input_dim.d[i];
} else {
// Padding dimension appended after the original ones.
expand_shape.d[i] = 1;
}
}
expand_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *X);
expand_layer->setReshapeDimensions(expand_shape);
// From here on X refers to the expanded tensor.
X = expand_layer->getOutput(0);
}
if (op_type_ == "add") { if (op_type_ == "add") {
nvinfer1::IScaleLayer* scale_layer = TRT_ENGINE_ADD_LAYER( nvinfer1::IScaleLayer* scale_layer = TRT_ENGINE_ADD_LAYER(
engine_, Scale, *X, scale_mode, shift_weights.get(), engine_, Scale, *X, scale_mode, shift_weights.get(),
...@@ -77,7 +96,17 @@ class ElementwiseWeightOpConverter : public OpConverter { ...@@ -77,7 +96,17 @@ class ElementwiseWeightOpConverter : public OpConverter {
shift_weights.get(), power_weights.get()); shift_weights.get(), power_weights.get());
layer = scale_layer; layer = scale_layer;
} }
// Undo the rank expansion: when the original input had fewer than 3 (+1 in
// dynamic-shape mode) dimensions, reshape the layer output back to the
// original number of dimensions.
if (input_dim.nbDims < 3 + dynamic_shape_offset) {
nvinfer1::Dims squeeze_shape;
squeeze_shape.nbDims = input_dim.nbDims;
for (int i = 0; i < squeeze_shape.nbDims; i++) {
// Negative (unknown) extents are written as 0. In IShuffleLayer reshape
// dimensions 0 is documented as "copy the matching input dimension"
// (verify against the TensorRT version in use).
squeeze_shape.d[i] = input_dim.d[i] < 0 ? 0 : input_dim.d[i];
}
squeeze_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *(layer->getOutput(0)));
squeeze_layer->setReshapeDimensions(squeeze_shape);
// The shuffle layer becomes the layer whose output is registered below.
layer = static_cast<nvinfer1::ILayer*>(squeeze_layer);
}
auto output_name = op_desc.Output("Out")[0]; auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer, "elementwise_" + op_type_, {output_name}, RreplenishLayerAndOutput(layer, "elementwise_" + op_type_, {output_name},
test_mode); test_mode);
......
...@@ -37,7 +37,7 @@ class FcOpConverter : public OpConverter { ...@@ -37,7 +37,7 @@ class FcOpConverter : public OpConverter {
const framework::Scope& scope, bool test_mode) override { const framework::Scope& scope, bool test_mode) override {
VLOG(3) << "convert a fluid fc op to tensorrt fc layer without bias"; VLOG(3) << "convert a fluid fc op to tensorrt fc layer without bias";
framework::OpDesc op_desc(op, nullptr); framework::OpDesc op_desc(op, nullptr);
auto output_name = op_desc.Output("Out").front();
auto input_names = op_desc.InputNames(); auto input_names = op_desc.InputNames();
bool with_bias = input_names.size() >= 3; bool with_bias = input_names.size() >= 3;
std::string w_name = "Y"; std::string w_name = "Y";
...@@ -54,7 +54,7 @@ class FcOpConverter : public OpConverter { ...@@ -54,7 +54,7 @@ class FcOpConverter : public OpConverter {
Y_v, platform::errors::NotFound( Y_v, platform::errors::NotFound(
"Can not find %s presistale var of fc in scope.", w_name)); "Can not find %s presistale var of fc in scope.", w_name));
auto* Y_t = Y_v->GetMutable<framework::LoDTensor>(); auto* Y_t = Y_v->GetMutable<framework::LoDTensor>();
const int x_num_col_dims = int x_num_col_dims =
op_desc.HasAttr("x_num_col_dims") op_desc.HasAttr("x_num_col_dims")
? BOOST_GET_CONST(int, op_desc.GetAttr("x_num_col_dims")) ? BOOST_GET_CONST(int, op_desc.GetAttr("x_num_col_dims"))
: (op_desc.HasAttr("in_num_col_dims") : (op_desc.HasAttr("in_num_col_dims")
...@@ -106,8 +106,8 @@ class FcOpConverter : public OpConverter { ...@@ -106,8 +106,8 @@ class FcOpConverter : public OpConverter {
auto regist_fc = [&](nvinfer1::ITensor* inputs, int n_output, auto regist_fc = [&](nvinfer1::ITensor* inputs, int n_output,
TensorRTEngine::Weight& weight, TensorRTEngine::Weight& weight,
TensorRTEngine::Weight& bias) { TensorRTEngine::Weight& bias) {
nvinfer1::ILayer* fc_layer = nullptr;
if (enable_int8) { if (enable_int8) {
// add conv layer
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
op_desc.HasAttr("out_threshold"), true, op_desc.HasAttr("out_threshold"), true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
...@@ -115,22 +115,46 @@ class FcOpConverter : public OpConverter { ...@@ -115,22 +115,46 @@ class FcOpConverter : public OpConverter {
float out_scale = float out_scale =
BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold")); BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold"));
nvinfer1::DimsHW nv_ksize(1, 1); nvinfer1::DimsHW nv_ksize(1, 1);
fc_layer = TRT_ENGINE_ADD_LAYER(engine_, Convolution, *inputs, n_output, auto* fc_layer_int8 =
nv_ksize, weight.get(), bias.get()); TRT_ENGINE_ADD_LAYER(engine_, Convolution, *inputs, n_output,
engine_->SetTensorDynamicRange(fc_layer->getOutput(0), out_scale); nv_ksize, weight.get(), bias.get());
} else { engine_->SetTensorDynamicRange(fc_layer_int8->getOutput(0), out_scale);
fc_layer = TRT_ENGINE_ADD_LAYER(engine_, FullyConnected, *inputs, if (activation_type == "relu") {
n_output, weight.get(), bias.get()); nvinfer1::IActivationLayer* relu_layer_int8 = TRT_ENGINE_ADD_LAYER(
} engine_, Activation, *(fc_layer_int8->getOutput(0)),
nvinfer1::ActivationType::kRELU);
auto output_name = op_desc.Output("Out").front(); RreplenishLayerAndOutput(relu_layer_int8, "relu_after_fc_shuffle",
if (activation_type == "relu") { {output_name}, test_mode);
nvinfer1::IActivationLayer* relu_layer = } else {
TRT_ENGINE_ADD_LAYER(engine_, Activation, *(fc_layer->getOutput(0)), RreplenishLayerAndOutput(fc_layer_int8, "shuffle_after_fc",
nvinfer1::ActivationType::kRELU); {output_name}, test_mode);
RreplenishLayerAndOutput(relu_layer, "fc", {output_name}, test_mode); }
} else { } else {
RreplenishLayerAndOutput(fc_layer, "fc", {output_name}, test_mode); // add fc layer
auto* fc_layer_before =
TRT_ENGINE_ADD_LAYER(engine_, FullyConnected, *inputs, n_output,
weight.get(), bias.get());
fc_layer_before->setName(
("fc_layer_before(Output: " + output_name + ")").c_str());
// add shuffle after fc
nvinfer1::Dims reshape_after_fc_dim;
reshape_after_fc_dim.nbDims = x_num_col_dims + 1;
for (int i = 0; i < reshape_after_fc_dim.nbDims; i++) {
reshape_after_fc_dim.d[i] = 0;
}
auto* fc_layer_float = TRT_ENGINE_ADD_LAYER(
engine_, Shuffle, *fc_layer_before->getOutput(0));
fc_layer_float->setReshapeDimensions(reshape_after_fc_dim);
if (activation_type == "relu") {
nvinfer1::IActivationLayer* relu_layer_float = TRT_ENGINE_ADD_LAYER(
engine_, Activation, *(fc_layer_float->getOutput(0)),
nvinfer1::ActivationType::kRELU);
RreplenishLayerAndOutput(relu_layer_float, "relu_after_fc_shuffle",
{output_name}, test_mode);
} else {
RreplenishLayerAndOutput(fc_layer_float, "shuffle_after_fc",
{output_name}, test_mode);
}
} }
}; };
...@@ -157,153 +181,43 @@ class FcOpConverter : public OpConverter { ...@@ -157,153 +181,43 @@ class FcOpConverter : public OpConverter {
static_cast<void*>(bias_data), static_cast<void*>(bias_data),
static_cast<size_t>(bias_num)}; static_cast<size_t>(bias_num)};
if (engine_->with_dynamic_shape()) { auto x_dim = X->getDimensions();
// not NCHW layout, but NLP layout with added 'x 1 x 1' // Running the TRT Static Shape mode: x_num_col_dims-1
auto x_dim = X->getDimensions(); if (!engine_->with_dynamic_shape()) {
if (engine_->use_oss() && engine_->with_ernie() && x_dim.nbDims == 4 && x_num_col_dims--;
x_dim.d[2] == 1 && x_dim.d[3] == 1 && x_num_col_dims == 2) {
// fc which is just after self attention
regist_fc(X, n_output, weight, bias);
return;
}
PADDLE_ENFORCE_LE(
x_dim.nbDims - x_num_col_dims, 3,
platform::errors::InvalidArgument(
"Params and input dims mismatch. Paddle-TRT FC "
"converter expects x_dim.nbDims - x_num_col_dims <= 3, but "
"x_dim.nbDims = %d, x_num_col_dims = %d.",
x_dim.nbDims, x_num_col_dims));
auto output_name = op_desc.Output("Out").front();
// add shuffle before fc
nvinfer1::Dims reshape_before_fc_dim;
// padding shape "x 1 x 1"
int padding_length = 3 - (x_dim.nbDims - x_num_col_dims);
reshape_before_fc_dim.nbDims = x_dim.nbDims + padding_length;
int cur_dim_index = reshape_before_fc_dim.nbDims - 1;
while (padding_length-- > 0) {
reshape_before_fc_dim.d[cur_dim_index--] = 1;
}
while (cur_dim_index >= 0) {
reshape_before_fc_dim.d[cur_dim_index--] = 0;
}
auto* reshape_before_fc_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *X);
reshape_before_fc_layer->setReshapeDimensions(reshape_before_fc_dim);
reshape_before_fc_layer->setName(
("shuffle_before_fc(Output: " + output_name + ")").c_str());
// add fc layer
auto* fc_layer = TRT_ENGINE_ADD_LAYER(
engine_, FullyConnected, *reshape_before_fc_layer->getOutput(0),
n_output, weight.get(), bias.get());
fc_layer->setName(("fc_layer(Output: " + output_name + ")").c_str());
// add shuffle after fc
nvinfer1::Dims reshape_after_fc_dim;
reshape_after_fc_dim.nbDims = x_num_col_dims + 1;
for (int i = 0; i < reshape_after_fc_dim.nbDims; i++) {
reshape_after_fc_dim.d[i] = 0;
}
auto* reshape_after_fc_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *fc_layer->getOutput(0));
reshape_after_fc_layer->setReshapeDimensions(reshape_after_fc_dim);
if (activation_type == "relu") {
reshape_after_fc_layer->setName(
("shuffle_after_fc(Output: " + output_name + ")").c_str());
nvinfer1::IActivationLayer* relu_layer = TRT_ENGINE_ADD_LAYER(
engine_, Activation, *(reshape_after_fc_layer->getOutput(0)),
nvinfer1::ActivationType::kRELU);
RreplenishLayerAndOutput(relu_layer, "relu_after_fc_shuffle",
{output_name}, test_mode);
} else {
RreplenishLayerAndOutput(reshape_after_fc_layer, "shuffle_after_fc",
{output_name}, test_mode);
}
return;
} }
// in order to handle situations in NLP models(input dims < 3, PADDLE_ENFORCE_GT(
// x_num_col_dims != 1, etc.), reshape input to perform FC correctly. x_dim.nbDims, x_num_col_dims,
auto* reshape_itensor = X; platform::errors::InvalidArgument(
int input_dims = X->getDimensions().nbDims; "Params and input dims mismatch. Paddle-TRT FC "
auto input_d = X->getDimensions().d; "converter expects x_dim.nbDims > x_num_col_dims, but "
int reshape_dim3[3] = {0}; "x_dim.nbDims : %d, x_num_col_dims : %d.",
int reshape_dim4[4] = {0}; x_dim.nbDims, x_num_col_dims));
PADDLE_ENFORCE_LE(x_num_col_dims, input_dims, // add shuffle before fc
platform::errors::InvalidArgument( nvinfer1::Dims reshape_before_fc_dim;
"Params and input dims mismatch. Paddle-TRT FC " reshape_before_fc_dim.nbDims = x_num_col_dims + 3;
"converter expects x_num_col_dims <= input dims")); // padding shape "* x q x 1 x 1"
if (x_num_col_dims == 1) { for (int i = 0; i < reshape_before_fc_dim.nbDims; i++) {
if (input_dims == 4) { reshape_before_fc_dim.d[i] = 1;
PADDLE_ENFORCE_EQ( }
input_d[3], 1, for (int i = 0; i < x_dim.nbDims; i++) {
platform::errors::InvalidArgument( if (i < x_num_col_dims) {
"Invalid dimensions. When x_num_col_dims equals to 1 and input " reshape_before_fc_dim.d[i] = 0;
"dims equals to 4, the last dim of input must be 1, but got %d",
input_d[3]));
}
if (enable_int8) {
reshape_dim3[0] = 1;
for (int i = 0; i < 3; i++) {
reshape_dim3[0] *= input_d[i];
if (i > 0) {
reshape_dim3[i] = 1;
}
}
} else {
for (int i = 0; i < 3; i++) {
if (i < input_dims) {
reshape_dim3[i] = input_d[i];
} else {
reshape_dim3[i] = 1;
}
}
}
nvinfer1::Dims3 reshape_dim(reshape_dim3[0], reshape_dim3[1],
reshape_dim3[2]);
auto* reshape_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *X);
reshape_layer->setReshapeDimensions(reshape_dim);
reshape_itensor = reshape_layer->getOutput(0);
if (enable_int8) {
engine_->SetTensorDynamicRange(reshape_itensor, in_scale);
}
} else {
PADDLE_ENFORCE_NE(input_dims, 1,
platform::errors::InvalidArgument(
"Invalid dimensions. When x_num_col_dims equals to "
"2, input_dims should not be 1"));
if (enable_int8) {
for (int i = 0; i < 4; i++) {
if (i == 0) {
reshape_dim4[i] = input_d[i];
} else {
reshape_dim4[i] = 1;
if (i < input_dims) {
reshape_dim4[1] *= input_d[i];
}
}
}
} else { } else {
for (int i = 0; i < 4; i++) { if (x_dim.d[i] < 0) {
if (i < input_dims) { reshape_before_fc_dim.d[x_num_col_dims] = -1;
reshape_dim4[i] = input_d[i]; break;
} else {
reshape_dim4[i] = 1;
}
} }
reshape_before_fc_dim.d[x_num_col_dims] *= x_dim.d[i];
} }
nvinfer1::Dims4 reshape_dim(reshape_dim4[0], reshape_dim4[1], }
reshape_dim4[2], reshape_dim4[3]); auto* reshape_before_fc_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *X);
auto* reshape_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *X); reshape_before_fc_layer->setReshapeDimensions(reshape_before_fc_dim);
reshape_layer->setReshapeDimensions(reshape_dim); reshape_before_fc_layer->setName(
reshape_itensor = reshape_layer->getOutput(0); ("shuffle_before_fc(Output: " + output_name + ")").c_str());
if (enable_int8) { auto* reshape_itensor = reshape_before_fc_layer->getOutput(0);
engine_->SetTensorDynamicRange(reshape_itensor, in_scale); if (enable_int8) {
} engine_->SetTensorDynamicRange(reshape_itensor, in_scale);
} }
regist_fc(reshape_itensor, n_output, weight, bias); regist_fc(reshape_itensor, n_output, weight, bias);
} }
......
...@@ -633,6 +633,20 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, ...@@ -633,6 +633,20 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
} }
} }
// Reject fc ops whose flatten attribute is below 1, so they are not handed to
// the TensorRT converter.
if (op_type == "fc") {
  // Prefer "x_num_col_dims", fall back to "in_num_col_dims", and default to 1
  // when neither attribute is present.
  const int x_num_col_dims =
      desc.HasAttr("x_num_col_dims")
          ? BOOST_GET_CONST(int, desc.GetAttr("x_num_col_dims"))
          : (desc.HasAttr("in_num_col_dims")
                 ? BOOST_GET_CONST(int, desc.GetAttr("in_num_col_dims"))
                 : 1);
  if (x_num_col_dims < 1) {
    // VLOG is a stream, not printf: the value must be streamed explicitly
    // (the previous message printed a literal "%d").
    VLOG(3) << "converter expects x_num_col_dims >= 1, "
               "but x_num_col_dims = "
            << x_num_col_dims << ".";
    return false;
  }
}
if ((*teller)(op_type, desc, use_no_calib_int8)) return true; if ((*teller)(op_type, desc, use_no_calib_int8)) return true;
} }
return false; return false;
......
...@@ -31,10 +31,7 @@ class FCFusePassTRTTest(InferencePassTest): ...@@ -31,10 +31,7 @@ class FCFusePassTRTTest(InferencePassTest):
size=128, size=128,
num_flatten_dims=1, num_flatten_dims=1,
act="relu") act="relu")
fc_out2 = fluid.layers.fc(input=fc_out1, out = fluid.layers.softmax(input=fc_out1)
size=32,
num_flatten_dims=1)
out = fluid.layers.softmax(input=fc_out2)
self.feeds = { self.feeds = {
"data": np.random.random((32, 128, 2, 2)).astype("float32") "data": np.random.random((32, 128, 2, 2)).astype("float32")
...@@ -55,6 +52,60 @@ class FCFusePassTRTTest(InferencePassTest): ...@@ -55,6 +52,60 @@ class FCFusePassTRTTest(InferencePassTest):
self.check_output_with_option(use_gpu[i]) self.check_output_with_option(use_gpu[i])
class FCFusePassTRTStaticDims4Cols1Test(InferencePassTest):
    """fc(relu) + softmax on a fixed 4-D input with num_flatten_dims=1, TensorRT enabled."""

    def setUp(self):
        input_shape = (32, 128, 32, 8)
        with fluid.program_guard(self.main_program, self.startup_program):
            image = fluid.data(
                name="data", shape=list(input_shape), dtype="float32")
            fc_relu = fluid.layers.fc(
                input=image, size=64, num_flatten_dims=1, act="relu")
            probs = fluid.layers.softmax(input=fc_relu)

        self.enable_trt = True
        self.trt_parameters = FCFusePassTRTStaticDims4Cols1Test.TensorRTParam(
            1 << 30, 32, 2, AnalysisConfig.Precision.Float32, False, False)
        # Random feed matching the declared shape of the "data" variable.
        self.feeds = {
            "data": np.random.random(input_shape).astype("float32")
        }
        self.fetch_list = [probs]

    def test_check_output(self):
        # Always run with use_gpu=False; add use_gpu=True only on CUDA builds.
        gpu_options = [False]
        if core.is_compiled_with_cuda():
            gpu_options.append(True)
        for with_gpu in gpu_options:
            self.check_output_with_option(with_gpu)
class FCFusePassTRTStaticDims4Cols2Test(InferencePassTest):
    """fc(relu) + softmax on a fixed 4-D input with num_flatten_dims=2, TensorRT enabled."""

    def setUp(self):
        input_shape = (3, 24, 16, 16)
        with fluid.program_guard(self.main_program, self.startup_program):
            image = fluid.data(
                name="data", shape=list(input_shape), dtype="float32")
            fc_relu = fluid.layers.fc(
                input=image, size=32, num_flatten_dims=2, act="relu")
            probs = fluid.layers.softmax(input=fc_relu)

        self.enable_trt = True
        self.trt_parameters = FCFusePassTRTStaticDims4Cols2Test.TensorRTParam(
            1 << 30, 32, 2, AnalysisConfig.Precision.Float32, False, False)
        # Random feed matching the declared shape of the "data" variable.
        self.feeds = {
            "data": np.random.random(input_shape).astype("float32")
        }
        self.fetch_list = [probs]

    def test_check_output(self):
        # Always run with use_gpu=False; add use_gpu=True only on CUDA builds.
        gpu_options = [False]
        if core.is_compiled_with_cuda():
            gpu_options.append(True)
        for with_gpu in gpu_options:
            self.check_output_with_option(with_gpu)
class FCFusePassTRTDynamicDims2Test(InferencePassTest): class FCFusePassTRTDynamicDims2Test(InferencePassTest):
def setUp(self): def setUp(self):
with fluid.program_guard(self.main_program, self.startup_program): with fluid.program_guard(self.main_program, self.startup_program):
......
...@@ -262,7 +262,6 @@ class TensorRTSubgraphPassInstanceNormTest(InferencePassTest): ...@@ -262,7 +262,6 @@ class TensorRTSubgraphPassInstanceNormTest(InferencePassTest):
with fluid.program_guard(self.main_program, self.startup_program): with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data( data = fluid.data(
name="data", shape=[-1, 3, 64, 64], dtype="float32") name="data", shape=[-1, 3, 64, 64], dtype="float32")
fc_out = fluid.layers.fc(input=data, size=200)
param_attr = fluid.ParamAttr( param_attr = fluid.ParamAttr(
name='instance_norm_w', name='instance_norm_w',
initializer=fluid.initializer.Constant(value=1.0)) initializer=fluid.initializer.Constant(value=1.0))
...@@ -270,7 +269,7 @@ class TensorRTSubgraphPassInstanceNormTest(InferencePassTest): ...@@ -270,7 +269,7 @@ class TensorRTSubgraphPassInstanceNormTest(InferencePassTest):
name='instance_norm_b', name='instance_norm_b',
initializer=fluid.initializer.Constant(value=0.0)) initializer=fluid.initializer.Constant(value=0.0))
out = fluid.layers.instance_norm( out = fluid.layers.instance_norm(
input=fc_out, param_attr=param_attr, bias_attr=bias_attr) input=data, param_attr=param_attr, bias_attr=bias_attr)
self.feeds = { self.feeds = {
"data": np.random.random([1, 3, 64, 64]).astype("float32"), "data": np.random.random([1, 3, 64, 64]).astype("float32"),
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册