diff --git a/paddle/fluid/inference/tensorrt/convert/fc_op.cc b/paddle/fluid/inference/tensorrt/convert/fc_op.cc index 2fa2cb97cc32b4243e8b825605e9689ecdab615f..ce6644cad4200f10a3e432469fda6c964dd7a94f 100644 --- a/paddle/fluid/inference/tensorrt/convert/fc_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/fc_op.cc @@ -1,8 +1,11 @@ /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -31,29 +34,57 @@ namespace tensorrt { class FcOpConverter : public OpConverter { public: nvinfer1::ILayer* reshape_before_fc(nvinfer1::ITensor* before_fc, - nvinfer1::Dims x_dim, int x_num_col_dims, + nvinfer1::Dims x_dim, + int x_num_col_dims, std::string output_name) { // add shuffle before fc nvinfer1::Dims reshape_before_fc_dim; reshape_before_fc_dim.nbDims = x_num_col_dims + 3; // padding shape "* x q x 1 x 1" - for (int i = 0; i < reshape_before_fc_dim.nbDims; i++) { - reshape_before_fc_dim.d[i] = 1; - } - for (int i = 0; i < x_dim.nbDims; i++) { - if (i < x_num_col_dims) { - reshape_before_fc_dim.d[i] = 0; - } else { - if (x_dim.d[i] < 0) { - reshape_before_fc_dim.d[x_num_col_dims] = -1; - break; + + nvinfer1::ITensor* filal_reshape_before_fc_shape_tensor = nullptr; + + if (!engine_->with_dynamic_shape()) { + for (int i = 0; i < reshape_before_fc_dim.nbDims; i++) { + reshape_before_fc_dim.d[i] = 1; + } + for (int i = 0; i < x_dim.nbDims; i++) { + if (i < x_num_col_dims) { + reshape_before_fc_dim.d[i] = 0; + } else { + reshape_before_fc_dim.d[x_num_col_dims] *= x_dim.d[i]; } - reshape_before_fc_dim.d[x_num_col_dims] *= x_dim.d[i]; } + } else { + std::vector reshape_before_fc_shape_tensor; + nvinfer1::ITensor* input_shape_tensor = Shape(before_fc); + + for (int i = 0; i < reshape_before_fc_dim.nbDims; i++) { + reshape_before_fc_shape_tensor.push_back(Add1DConstantLayer(1)); + } + for (int i = 0; i < x_dim.nbDims; i++) { + if (i < x_num_col_dims) { + reshape_before_fc_shape_tensor[i] = + GetEleTensorOfShape(input_shape_tensor, i); + } else { + reshape_before_fc_shape_tensor[x_num_col_dims] = + Prod(GetEleTensorOfShape(input_shape_tensor, i), + reshape_before_fc_shape_tensor[x_num_col_dims]); + } + } + filal_reshape_before_fc_shape_tensor = + Concat(reshape_before_fc_shape_tensor); } + auto* reshape_before_fc_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *before_fc); - reshape_before_fc_layer->setReshapeDimensions(reshape_before_fc_dim); + if (!engine_->with_dynamic_shape()) { + reshape_before_fc_layer->setReshapeDimensions(reshape_before_fc_dim); + } else { + reshape_before_fc_layer->setInput(1, + *filal_reshape_before_fc_shape_tensor); + } + reshape_before_fc_layer->setName( ("fc_op_reshape_before_fc: Shuffle (Output: " + output_name + ")") .c_str()); @@ -61,21 +92,39 @@ class FcOpConverter : public OpConverter { } nvinfer1::ILayer* reshape_after_fc(nvinfer1::ITensor* after_fc, - nvinfer1::Dims x_dim, int x_num_col_dims) { + nvinfer1::Dims x_dim, + int x_num_col_dims) { // add shuffle after fc nvinfer1::Dims reshape_after_fc_dim; reshape_after_fc_dim.nbDims = x_num_col_dims + 1; - for (int i = 0; i < reshape_after_fc_dim.nbDims; i++) { - reshape_after_fc_dim.d[i] = 0; + + nvinfer1::ITensor* filal_reshape_after_fc_shape_tensor = nullptr; + + if (!engine_->with_dynamic_shape()) { + for (int i = 0; i < reshape_after_fc_dim.nbDims; i++) { + reshape_after_fc_dim.d[i] = 0; + } + } else { + std::vector gather_indices(x_num_col_dims + 1); + std::iota(gather_indices.begin(), gather_indices.end(), 0); + filal_reshape_after_fc_shape_tensor = + Gather(Shape(after_fc), gather_indices); } + auto* reshape_after_fc_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *after_fc); - reshape_after_fc_layer->setReshapeDimensions(reshape_after_fc_dim); + if (!engine_->with_dynamic_shape()) { + reshape_after_fc_layer->setReshapeDimensions(reshape_after_fc_dim); + } else { + reshape_after_fc_layer->setInput(1, *filal_reshape_after_fc_shape_tensor); + } + return reshape_after_fc_layer; } void operator()(const framework::proto::OpDesc& op, - const framework::Scope& scope, bool test_mode) override { + const framework::Scope& scope, + bool test_mode) override { VLOG(3) << "convert a fluid fc op to tensorrt fc layer without bias"; framework::OpDesc op_desc(op, nullptr); auto output_name = op_desc.Output("Out").front(); @@ -93,8 +142,9 @@ class FcOpConverter : public OpConverter { // Declare weights auto* Y_v = scope.FindVar(op_desc.Input(w_name).front()); PADDLE_ENFORCE_NOT_NULL( - Y_v, platform::errors::NotFound( - "Can not find %s presistale var of fc in scope.", w_name)); + Y_v, + platform::errors::NotFound( + "Can not find %s presistale var of fc in scope.", w_name)); auto* Y_t = Y_v->GetMutable(); int x_num_col_dims = op_desc.HasAttr("x_num_col_dims") @@ -125,7 +175,8 @@ class FcOpConverter : public OpConverter { } weight_data = engine_->GetWeightCPUData(op_desc.Input(w_name).front(), Y_t); - PADDLE_ENFORCE_EQ(Y_t->dims().size(), 2UL, + PADDLE_ENFORCE_EQ(Y_t->dims().size(), + 2UL, platform::errors::InvalidArgument( "The fc's weight should be a matrix with 2 dims, but " "it's %d-dimensional.", @@ -140,7 +191,8 @@ class FcOpConverter : public OpConverter { } }; - auto regist_fc = [&](nvinfer1::ITensor* inputs, int n_output, + auto regist_fc = [&](nvinfer1::ITensor* inputs, + int n_output, TensorRTEngine::Weight& weight, TensorRTEngine::Weight& bias) { if (enable_int8 || support_int8) { @@ -148,7 +200,8 @@ class FcOpConverter : public OpConverter { float out_scale = 0; if (enable_int8) { PADDLE_ENFORCE_EQ( - op_desc.HasAttr("out_threshold"), true, + op_desc.HasAttr("out_threshold"), + true, platform::errors::InvalidArgument( "must have out threshold in fc layers in int8 mode")); out_scale = BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold")); @@ -156,9 +209,13 @@ class FcOpConverter : public OpConverter { out_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Out")); } nvinfer1::DimsHW nv_ksize(1, 1); - auto* fc_layer_int8 = - TRT_ENGINE_ADD_LAYER(engine_, Convolution, *inputs, n_output, - nv_ksize, weight.get(), bias.get()); + auto* fc_layer_int8 = TRT_ENGINE_ADD_LAYER(engine_, + Convolution, + *inputs, + n_output, + nv_ksize, + weight.get(), + bias.get()); fc_layer_int8->setName( ("fc_op_int8_conv1x1: Convolution (Output: " + output_name + ")") .c_str()); @@ -171,21 +228,29 @@ class FcOpConverter : public OpConverter { .c_str()); engine_->SetTensorDynamicRange(fc_after_reshape_int8->getOutput(0), out_scale); - nvinfer1::IActivationLayer* relu_layer_int8 = TRT_ENGINE_ADD_LAYER( - engine_, Activation, *(fc_after_reshape_int8->getOutput(0)), - nvinfer1::ActivationType::kRELU); - RreplenishLayerAndOutput(relu_layer_int8, "relu_after_fc_shuffle", - {output_name}, test_mode); + nvinfer1::IActivationLayer* relu_layer_int8 = + TRT_ENGINE_ADD_LAYER(engine_, + Activation, + *(fc_after_reshape_int8->getOutput(0)), + nvinfer1::ActivationType::kRELU); + RreplenishLayerAndOutput(relu_layer_int8, + "relu_after_fc_shuffle", + {output_name}, + test_mode); } else { RreplenishLayerAndOutput(fc_after_reshape_int8, "fc_op_int8_reshape_after_fc: Shuffle", - {output_name}, test_mode); + {output_name}, + test_mode); } } else { // add fc layer - auto* fc_layer_float = - TRT_ENGINE_ADD_LAYER(engine_, FullyConnected, *inputs, n_output, - weight.get(), bias.get()); + auto* fc_layer_float = TRT_ENGINE_ADD_LAYER(engine_, + FullyConnected, + *inputs, + n_output, + weight.get(), + bias.get()); fc_layer_float->setName( ("fc_op_float: FullyConnected (Output: " + output_name + ")") .c_str()); @@ -195,14 +260,20 @@ class FcOpConverter : public OpConverter { fc_after_reshape_float->setName( ("float_reshape_after_fc: Shuffle (Output: " + output_name + ")") .c_str()); - nvinfer1::IActivationLayer* relu_layer_float = TRT_ENGINE_ADD_LAYER( - engine_, Activation, *(fc_after_reshape_float->getOutput(0)), - nvinfer1::ActivationType::kRELU); - RreplenishLayerAndOutput(relu_layer_float, "relu_after_fc_shuffle", - {output_name}, test_mode); + nvinfer1::IActivationLayer* relu_layer_float = + TRT_ENGINE_ADD_LAYER(engine_, + Activation, + *(fc_after_reshape_float->getOutput(0)), + nvinfer1::ActivationType::kRELU); + RreplenishLayerAndOutput(relu_layer_float, + "relu_after_fc_shuffle", + {output_name}, + test_mode); } else { - RreplenishLayerAndOutput(fc_after_reshape_float, "shuffle_after_fc", - {output_name}, test_mode); + RreplenishLayerAndOutput(fc_after_reshape_float, + "shuffle_after_fc", + {output_name}, + test_mode); } } }; @@ -251,15 +322,20 @@ class FcOpConverter : public OpConverter { if (enable_int8 || support_int8) { // add conv1x1 layer nvinfer1::DimsHW nv_ksize(1, 1); - auto* fc_layer_int8 = - TRT_ENGINE_ADD_LAYER(engine_, Convolution, *X, n_output, nv_ksize, - weight.get(), bias.get()); + auto* fc_layer_int8 = TRT_ENGINE_ADD_LAYER(engine_, + Convolution, + *X, + n_output, + nv_ksize, + weight.get(), + bias.get()); if (activation_type == "relu") { fc_layer_int8->setName( ("ernie_fc_op_int8: Convolution (Output: " + output_name + ")") .c_str()); PADDLE_ENFORCE_EQ( - op_desc.HasAttr("out_threshold"), true, + op_desc.HasAttr("out_threshold"), + true, platform::errors::InvalidArgument( "must have out threshold in fc layers in int8 mode")); float out_scale = 0; @@ -271,15 +347,20 @@ class FcOpConverter : public OpConverter { } engine_->SetTensorDynamicRange(fc_layer_int8->getOutput(0), out_scale); - nvinfer1::IActivationLayer* relu_layer_int8 = TRT_ENGINE_ADD_LAYER( - engine_, Activation, *(fc_layer_int8->getOutput(0)), - nvinfer1::ActivationType::kRELU); - RreplenishLayerAndOutput(relu_layer_int8, "relu_after_ernie_fc_int8", - {output_name}, test_mode); + nvinfer1::IActivationLayer* relu_layer_int8 = + TRT_ENGINE_ADD_LAYER(engine_, + Activation, + *(fc_layer_int8->getOutput(0)), + nvinfer1::ActivationType::kRELU); + RreplenishLayerAndOutput(relu_layer_int8, + "relu_after_ernie_fc_int8", + {output_name}, + test_mode); } else { RreplenishLayerAndOutput(fc_layer_int8, "ernie_fc_op_int8: Convolution", - {output_name}, test_mode); + {output_name}, + test_mode); } } else { // add fc layer @@ -288,25 +369,30 @@ class FcOpConverter : public OpConverter { if (activation_type == "relu") { fc_layer_float->setName( ("ernie_fc_op_float: (Output: " + output_name + ")").c_str()); - nvinfer1::IActivationLayer* relu_layer_float = TRT_ENGINE_ADD_LAYER( - engine_, Activation, *(fc_layer_float->getOutput(0)), - nvinfer1::ActivationType::kRELU); + nvinfer1::IActivationLayer* relu_layer_float = + TRT_ENGINE_ADD_LAYER(engine_, + Activation, + *(fc_layer_float->getOutput(0)), + nvinfer1::ActivationType::kRELU); RreplenishLayerAndOutput(relu_layer_float, - "relu_after_ernie_fc_float", {output_name}, + "relu_after_ernie_fc_float", + {output_name}, test_mode); } else { - RreplenishLayerAndOutput(fc_layer_float, "ernie_fc_op_float", - {output_name}, test_mode); + RreplenishLayerAndOutput( + fc_layer_float, "ernie_fc_op_float", {output_name}, test_mode); } } } else { // need reshape input before and after fc PADDLE_ENFORCE_GT( - x_dim.nbDims, x_num_col_dims, + x_dim.nbDims, + x_num_col_dims, platform::errors::InvalidArgument( "Params and input dims mismatch. Paddle-TRT FC " "converter expects x_dim.nbDims > x_num_col_dims, but " "x_dim.nbDims : %d, x_num_col_dims : %d.", - x_dim.nbDims, x_num_col_dims)); + x_dim.nbDims, + x_num_col_dims)); auto* reshape_before_fc_layer = reshape_before_fc(X, x_dim, x_num_col_dims, output_name); auto* reshape_itensor = reshape_before_fc_layer->getOutput(0); diff --git a/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc b/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc index b89d62c97db13a88ef2bfc82720587eef7c02ba5..d30dc5eb35b15c2c842e6ab54b007a6810ac0071 100644 --- a/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc @@ -1,8 +1,11 @@ /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See @@ -19,7 +22,8 @@ namespace tensorrt { class MultiheadMatMulOpConverter : public OpConverter { public: void operator()(const framework::proto::OpDesc& op, - const framework::Scope& scope, bool test_mode) override { + const framework::Scope& scope, + bool test_mode) override { VLOG(3) << "convert a fluid multihead_mamul op to a corresponding tensorrt " "network structure"; framework::OpDesc op_desc(op, nullptr); @@ -49,8 +53,8 @@ class MultiheadMatMulOpConverter : public OpConverter { float* bias_data = engine_->GetWeightCPUData(bias_name, bias_t); std::vector weight_data_tmp; weight_data_tmp.reserve(weight_t->numel()); - memcpy(weight_data_tmp.data(), weight_data, - weight_t->numel() * sizeof(float)); + memcpy( + weight_data_tmp.data(), weight_data, weight_t->numel() * sizeof(float)); // (hidden_in, 3, hidden_out) const auto& weight_dims = weight_t->dims(); @@ -98,14 +102,15 @@ class MultiheadMatMulOpConverter : public OpConverter { nvinfer1::ILayer* fc_layer = nullptr; float dp_probs = 1.0 / 127.0; nvinfer1::DimsHW nv_ksize(1, 1); - fc_layer = TRT_ENGINE_ADD_LAYER(engine_, Convolution, *input, n, - nv_ksize, weight, bias); + fc_layer = TRT_ENGINE_ADD_LAYER( + engine_, Convolution, *input, n, nv_ksize, weight, bias); fc_layer->setName( ("Multihead: Convolution/FullyConnected: (Output: " + output_name + ")") .c_str()); PADDLE_ENFORCE_EQ( - op_desc.HasAttr("fc_out_threshold"), true, + op_desc.HasAttr("fc_out_threshold"), + true, platform::errors::InvalidArgument( "must have out_threshold in multihead layers in int8 mode")); float out_scale = @@ -119,13 +124,19 @@ class MultiheadMatMulOpConverter : public OpConverter { "CustomQKVToContextPluginDynamic", "3"); assert(creator != nullptr); std::vector fields{ - {"hidden_size", &hidden_out, nvinfer1::PluginFieldType::kINT32, + {"hidden_size", + &hidden_out, + nvinfer1::PluginFieldType::kINT32, 1}, - {"num_heads", &head_number, nvinfer1::PluginFieldType::kINT32, + {"num_heads", + &head_number, + nvinfer1::PluginFieldType::kINT32, 1}}; if (qkv2context_plugin_int8) { - fields.push_back({"dq_probs", &dp_probs, - nvinfer1::PluginFieldType::kFLOAT32, 1}); + fields.push_back({"dq_probs", + &dp_probs, + nvinfer1::PluginFieldType::kFLOAT32, + 1}); } nvinfer1::PluginFieldCollection* plugin_collection = static_cast(malloc( @@ -154,7 +165,8 @@ class MultiheadMatMulOpConverter : public OpConverter { engine_->GetITensor(engine_->network()->getInput(3)->getName()); engine_->SetTensorDynamicRange(max_seqlen_tensor, 1.0f); auto* shuffle_layer = TRT_ENGINE_ADD_LAYER( - engine_, Shuffle, + engine_, + Shuffle, *const_cast(max_seqlen_tensor)); nvinfer1::Dims shape_dim; shape_dim.nbDims = 1; @@ -173,8 +185,11 @@ class MultiheadMatMulOpConverter : public OpConverter { // [3, head_number, head_size, hidden_in] -> [head_number, 3, // head_size, // hidden_in] - auto transpose_weight_v2 = [](const float* src, float* dst, int three, - int head_number, int head_size, + auto transpose_weight_v2 = [](const float* src, + float* dst, + int three, + int head_number, + int head_size, int hidden_in) { const int HH = head_size * hidden_in; for (int i = 0; i < three; ++i) { @@ -187,41 +202,47 @@ class MultiheadMatMulOpConverter : public OpConverter { } }; // [3, head_number, head_size] -> [head_number, 3, head_size] - auto transpose_bias_v2 = [](const float* src, float* dst, int N, - int H) { - for (int i = 0; i < 3; ++i) { - for (int n = 0; n < N; ++n) { - for (int h = 0; h < H; ++h) { - dst[n * 3 * H + i * H + h] = src[i * N * H + n * H + h]; + auto transpose_bias_v2 = + [](const float* src, float* dst, int N, int H) { + for (int i = 0; i < 3; ++i) { + for (int n = 0; n < N; ++n) { + for (int h = 0; h < H; ++h) { + dst[n * 3 * H + i * H + h] = src[i * N * H + n * H + h]; + } + } } - } - } - }; - memcpy(weight_data_tmp.data(), weight_data, + }; + memcpy(weight_data_tmp.data(), + weight_data, weight_t->numel() * sizeof(float)); - transpose_weight_v2(weight_data_tmp.data(), weight_data, three, - head_number, head_size, hidden_in); + transpose_weight_v2(weight_data_tmp.data(), + weight_data, + three, + head_number, + head_size, + hidden_in); std::vector bias_data_tmp; bias_data_tmp.reserve(bias_t->numel()); - memcpy(bias_data_tmp.data(), bias_data, - bias_t->numel() * sizeof(float)); - transpose_bias_v2(bias_data_tmp.data(), bias_data, head_number, - head_size); + memcpy( + bias_data_tmp.data(), bias_data, bias_t->numel() * sizeof(float)); + transpose_bias_v2( + bias_data_tmp.data(), bias_data, head_number, head_size); nvinfer1::ILayer* fc_layer = nullptr; float dp_probs = 1.0 / 127.0; if (op_desc.HasAttr("Input_scale")) { nvinfer1::DimsHW nv_ksize(1, 1); - fc_layer = TRT_ENGINE_ADD_LAYER(engine_, Convolution, *input, n, - nv_ksize, weight, bias); + fc_layer = TRT_ENGINE_ADD_LAYER( + engine_, Convolution, *input, n, nv_ksize, weight, bias); } else { - fc_layer = TRT_ENGINE_ADD_LAYER(engine_, FullyConnected, *input, n, - weight, bias); + fc_layer = TRT_ENGINE_ADD_LAYER( + engine_, FullyConnected, *input, n, weight, bias); } if (op_desc.HasAttr("fc_out_threshold")) { - PADDLE_ENFORCE_EQ(op_desc.HasAttr("fc_out_threshold"), true, + PADDLE_ENFORCE_EQ(op_desc.HasAttr("fc_out_threshold"), + true, platform::errors::InvalidArgument( "must have out threshold in multihead layers " "in int8 mode")); @@ -245,15 +266,21 @@ class MultiheadMatMulOpConverter : public OpConverter { int var_seqlen = 1; std::vector fields{ {"type_id", &type, nvinfer1::PluginFieldType::kINT32, 1}, - {"hidden_size", &hidden_out, nvinfer1::PluginFieldType::kINT32, + {"hidden_size", + &hidden_out, + nvinfer1::PluginFieldType::kINT32, 1}, {"num_heads", &head_number, nvinfer1::PluginFieldType::kINT32, 1}, {"has_mask", &has_mask, nvinfer1::PluginFieldType::kINT32, 1}, - {"var_seqlen", &var_seqlen, nvinfer1::PluginFieldType::kINT32, + {"var_seqlen", + &var_seqlen, + nvinfer1::PluginFieldType::kINT32, 1}}; if (qkv2context_plugin_int8) { - fields.push_back({"dq_probs", &dp_probs, - nvinfer1::PluginFieldType::kFLOAT32, 1}); + fields.push_back({"dq_probs", + &dp_probs, + nvinfer1::PluginFieldType::kFLOAT32, + 1}); } nvinfer1::PluginFieldCollection* plugin_collection = static_cast(malloc( @@ -274,7 +301,8 @@ class MultiheadMatMulOpConverter : public OpConverter { auto max_seqlen_tensor = engine_->GetITensor("mask_id"); auto* shuffle_layer = TRT_ENGINE_ADD_LAYER( - engine_, Shuffle, + engine_, + Shuffle, *const_cast(max_seqlen_tensor)); nvinfer1::Dims shape_dim; shape_dim.nbDims = 1; @@ -290,7 +318,8 @@ class MultiheadMatMulOpConverter : public OpConverter { } } else { PADDLE_ENFORCE_EQ( - input->getDimensions().nbDims, 3, + input->getDimensions().nbDims, + 3, platform::errors::InvalidArgument( "The Input dim of the MultiheadMatMul should be 3, " "but it's (%d) now.", @@ -309,20 +338,24 @@ class MultiheadMatMulOpConverter : public OpConverter { static_cast(bias_t->numel())}; // add shuffle before fc - nvinfer1::Dims reshape_before_fc_dim; - reshape_before_fc_dim.nbDims = 5; - reshape_before_fc_dim.d[0] = 0; - reshape_before_fc_dim.d[1] = 0; - reshape_before_fc_dim.d[2] = 0; - reshape_before_fc_dim.d[3] = 1; - reshape_before_fc_dim.d[4] = 1; + std::vector reshape_before_fc_shape_tensor; + nvinfer1::ITensor* input_shape_tensor = Shape(input); + + for (int i = 0; i < 5; i++) { + reshape_before_fc_shape_tensor.push_back(Add1DConstantLayer(1)); + } + for (int i = 0; i < 3; i++) { + reshape_before_fc_shape_tensor[i] = + GetEleTensorOfShape(input_shape_tensor, i); + } auto* reshape_before_fc_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input); if (op_desc.HasAttr("Input_scale")) { engine_->SetTensorDynamicRange(reshape_before_fc_layer->getOutput(0), in_scale); } - reshape_before_fc_layer->setReshapeDimensions(reshape_before_fc_dim); + reshape_before_fc_layer->setInput( + 1, *Concat(reshape_before_fc_shape_tensor)); reshape_before_fc_layer->setName( ("shuffle_before_multihead_mamul(Output: " + output_name + ")") .c_str()); @@ -331,18 +364,28 @@ class MultiheadMatMulOpConverter : public OpConverter { nvinfer1::ILayer* fc_layer = nullptr; if (op_desc.HasAttr("Input_scale")) { nvinfer1::DimsHW nv_ksize(1, 1); - fc_layer = TRT_ENGINE_ADD_LAYER( - engine_, Convolution, *reshape_before_fc_layer->getOutput(0), n, - nv_ksize, weight.get(), bias.get()); + fc_layer = + TRT_ENGINE_ADD_LAYER(engine_, + Convolution, + *reshape_before_fc_layer->getOutput(0), + n, + nv_ksize, + weight.get(), + bias.get()); } else { - fc_layer = TRT_ENGINE_ADD_LAYER( - engine_, FullyConnected, *reshape_before_fc_layer->getOutput(0), - n, weight.get(), bias.get()); + fc_layer = + TRT_ENGINE_ADD_LAYER(engine_, + FullyConnected, + *reshape_before_fc_layer->getOutput(0), + n, + weight.get(), + bias.get()); } if (op_desc.HasAttr("fc_out_threshold")) { PADDLE_ENFORCE_EQ( - op_desc.HasAttr("fc_out_threshold"), true, + op_desc.HasAttr("fc_out_threshold"), + true, platform::errors::InvalidArgument( "must have out threshold in multihead layers in int8 mode")); float out_scale = @@ -369,8 +412,8 @@ class MultiheadMatMulOpConverter : public OpConverter { with_fp16 = true; } plugin::DynamicPluginTensorRT* plugin = - new plugin::QkvToContextPluginDynamic(hidden_in, head_number, - head_size, scale, with_fp16); + new plugin::QkvToContextPluginDynamic( + hidden_in, head_number, head_size, scale, with_fp16); layer = engine_->AddDynamicPlugin(plugin_inputs.data(), 2, plugin); } } else { @@ -380,8 +423,8 @@ class MultiheadMatMulOpConverter : public OpConverter { "You can use the config.SetTRTDynamicShapeInfo(...) interface to set " "the shape information to run the dynamic shape mode.")); } - RreplenishLayerAndOutput(layer, "multihead_matmul", {output_name}, - test_mode); + RreplenishLayerAndOutput( + layer, "multihead_matmul", {output_name}, test_mode); } }; diff --git a/paddle/fluid/inference/tensorrt/convert/op_converter.h b/paddle/fluid/inference/tensorrt/convert/op_converter.h index f6ecf76d016759a2df05d8423635f0d560874ac2..d179e8bb34c16c60105a2b3f4db2e9256cf32ec1 100644 --- a/paddle/fluid/inference/tensorrt/convert/op_converter.h +++ b/paddle/fluid/inference/tensorrt/convert/op_converter.h @@ -47,14 +47,16 @@ class OpConverter { // test_mode: whether the instance executes in an unit test. void ConvertOp(const framework::proto::OpDesc& op, const std::unordered_set& parameters, - const framework::Scope& scope, TensorRTEngine* engine, + const framework::Scope& scope, + TensorRTEngine* engine, bool test_mode = false) { framework::OpDesc op_desc(op, nullptr); OpConverter* it{nullptr}; if (op_desc.Type() == "mul") { - PADDLE_ENFORCE_EQ(op_desc.Input("Y").size(), 1UL, + PADDLE_ENFORCE_EQ(op_desc.Input("Y").size(), + 1UL, platform::errors::InvalidArgument( "The input op mul's Input(\"Y\")." "size() should equal to 1, but reveceid " @@ -70,7 +72,8 @@ class OpConverter { "add", "mul", "sub", "div", "max", "min", "pow"}; static std::unordered_set add_weight_op_set{ "add", "mul", "sub", "div", "pow"}; - PADDLE_ENFORCE_EQ(op_desc.Input("Y").size(), 1UL, + PADDLE_ENFORCE_EQ(op_desc.Input("Y").size(), + 1UL, platform::errors::InvalidArgument( "The input op's Input(\"Y\")." "size() should equal to 1, but reveceid " @@ -81,64 +84,74 @@ class OpConverter { std::string Y = op_desc.Input("Y")[0]; if (parameters.count(Y)) { PADDLE_ENFORCE_GT( - add_weight_op_set.count(op_type), 0, + add_weight_op_set.count(op_type), + 0, platform::errors::Unimplemented("Unsupported elementwise type %s", op_type.c_str())); it = Registry::Global().Lookup("elementwise_" + op_type + "_weight"); PADDLE_ENFORCE_NOT_NULL( - it, platform::errors::Unimplemented( - "no OpConverter for optype [%s]", op_desc.Type())); + it, + platform::errors::Unimplemented("no OpConverter for optype [%s]", + op_desc.Type())); } else { PADDLE_ENFORCE_GT( - add_tensor_op_set.count(op_type), 0, + add_tensor_op_set.count(op_type), + 0, platform::errors::Unimplemented("Unsupported elementwise type %s", op_type.c_str())); it = Registry::Global().Lookup("elementwise_" + op_type + "_tensor"); } PADDLE_ENFORCE_NOT_NULL( - it, platform::errors::Unimplemented("no OpConverter for optype [%s]", - op_desc.Type())); + it, + platform::errors::Unimplemented("no OpConverter for optype [%s]", + op_desc.Type())); } if (op_desc.Type() == "depthwise_conv2d") { it = Registry::Global().Lookup("conv2d"); PADDLE_ENFORCE_NOT_NULL( - it, platform::errors::Unimplemented("no OpConverter for optype [%s]", - op_desc.Type())); + it, + platform::errors::Unimplemented("no OpConverter for optype [%s]", + op_desc.Type())); } if (op_desc.Type() == "depthwise_conv2d_transpose") { it = Registry::Global().Lookup("conv2d_transpose"); PADDLE_ENFORCE_NOT_NULL( - it, platform::errors::Unimplemented("no OpConverter for optype [%s]", - op_desc.Type())); + it, + platform::errors::Unimplemented("no OpConverter for optype [%s]", + op_desc.Type())); } if (op_desc.Type() == "transpose2") { it = Registry::Global().Lookup("transpose"); PADDLE_ENFORCE_NOT_NULL( - it, platform::errors::Unimplemented("no OpConverter for optype [%s]", - op_desc.Type())); + it, + platform::errors::Unimplemented("no OpConverter for optype [%s]", + op_desc.Type())); } if (op_desc.Type() == "flatten2") { it = Registry::Global().Lookup("flatten"); PADDLE_ENFORCE_NOT_NULL( - it, platform::errors::Unimplemented("no OpConverter for optype [%s]", - op_desc.Type())); + it, + platform::errors::Unimplemented("no OpConverter for optype [%s]", + op_desc.Type())); } // reshape2 == reshape if (op_desc.Type() == "reshape2") { it = Registry::Global().Lookup("reshape"); PADDLE_ENFORCE_NOT_NULL( - it, platform::errors::Unimplemented("no OpConverter for optype [%s]", - op_desc.Type())); + it, + platform::errors::Unimplemented("no OpConverter for optype [%s]", + op_desc.Type())); } if (!it) { it = Registry::Global().Lookup(op_desc.Type()); } PADDLE_ENFORCE_NOT_NULL( - it, platform::errors::Unimplemented("no OpConverter for optype [%s]", - op_desc.Type())); + it, + platform::errors::Unimplemented("no OpConverter for optype [%s]", + op_desc.Type())); it->SetEngine(engine); (*it)(op, scope, test_mode); @@ -214,7 +227,8 @@ class OpConverter { // the INetwork's inputs and outputs should specified in some other modules. void ConvertBlock(const framework::proto::BlockDesc& block, const std::unordered_set& parameters, - const framework::Scope& scope, TensorRTEngine* engine) { + const framework::Scope& scope, + TensorRTEngine* engine) { std::unique_lock lk(mut_); for (int i = 0; i < block.ops_size(); i++) { const auto& op = block.ops(i); @@ -224,20 +238,24 @@ class OpConverter { // The scope here should be inited with the parameter vars. void ConvertBlockToTRTEngine( - framework::BlockDesc* block_desc, const framework::Scope& scope, + framework::BlockDesc* block_desc, + const framework::Scope& scope, const std::vector& inputs, const std::unordered_set& parameters, - const std::vector& outputs, TensorRTEngine* engine) { + const std::vector& outputs, + TensorRTEngine* engine) { engine->InitNetwork(); bool all_dynamic_shape_set = true; for (auto& input : inputs) { if (parameters.count(input)) continue; auto* var = block_desc->FindVar(input); PADDLE_ENFORCE_NOT_NULL( - var, platform::errors::NotFound("no variable called %s in block.", - input.c_str())); + var, + platform::errors::NotFound("no variable called %s in block.", + input.c_str())); PADDLE_ENFORCE_EQ( - var->GetType(), FluidDT::VarType_Type_LOD_TENSOR, + var->GetType(), + FluidDT::VarType_Type_LOD_TENSOR, platform::errors::InvalidArgument("TensorRT engine only takes " "LoDTensor as input")); auto var_shape = var->GetShape(); @@ -262,7 +280,8 @@ class OpConverter { } else { input_shape.push_back(min_input_shape[i]); // the i dimension should be same. - PADDLE_ENFORCE_EQ(min_input_shape[i], optim_input_shape[i], + PADDLE_ENFORCE_EQ(min_input_shape[i], + optim_input_shape[i], platform::errors::InvalidArgument( "The dim (%d) of the min_input_shape and " "optim_input_shape should be same.")); @@ -282,7 +301,8 @@ class OpConverter { Vec2TRT_Dims(var_shape, input)); } } - PADDLE_ENFORCE_EQ(all_dynamic_shape_set, true, + PADDLE_ENFORCE_EQ(all_dynamic_shape_set, + true, platform::errors::InvalidArgument( "some trt inputs dynamic shape info not set, " "check the INFO log above for more details.")); @@ -297,7 +317,8 @@ class OpConverter { // rank(result) = rank(input) nvinfer1::ITensor* Gather(nvinfer1::ITensor* input, - const std::vector indices, int axis = 0) { + const std::vector indices, + int axis = 0) { auto* indices_tensor = Add1DConstantLayer(indices, " "); auto* result = TRT_ENGINE_ADD_LAYER(engine_, Gather, *input, *indices_tensor, axis) @@ -326,8 +347,8 @@ class OpConverter { // Concat not make rank changed nvinfer1::ITensor* Concat(const std::vector& inputs, int axis = 0) { - auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Concatenation, inputs.data(), - inputs.size()); + auto* layer = TRT_ENGINE_ADD_LAYER( + engine_, Concatenation, inputs.data(), inputs.size()); if (axis != 0) layer->setAxis(axis); nvinfer1::ITensor* c = layer->getOutput(0); return c; @@ -335,48 +356,48 @@ class OpConverter { nvinfer1::ITensor* Sum(nvinfer1::ITensor* a, nvinfer1::ITensor* b) { nvinfer1::ITensor* c = - TRT_ENGINE_ADD_LAYER(engine_, ElementWise, *a, *b, - nvinfer1::ElementWiseOperation::kSUM) + TRT_ENGINE_ADD_LAYER( + engine_, ElementWise, *a, *b, nvinfer1::ElementWiseOperation::kSUM) ->getOutput(0); return c; } nvinfer1::ITensor* Prod(nvinfer1::ITensor* a, nvinfer1::ITensor* b) { nvinfer1::ITensor* c = - TRT_ENGINE_ADD_LAYER(engine_, ElementWise, *a, *b, - nvinfer1::ElementWiseOperation::kPROD) + TRT_ENGINE_ADD_LAYER( + engine_, ElementWise, *a, *b, nvinfer1::ElementWiseOperation::kPROD) ->getOutput(0); return c; } nvinfer1::ITensor* Min(nvinfer1::ITensor* a, nvinfer1::ITensor* b) { nvinfer1::ITensor* c = - TRT_ENGINE_ADD_LAYER(engine_, ElementWise, *a, *b, - nvinfer1::ElementWiseOperation::kMIN) + TRT_ENGINE_ADD_LAYER( + engine_, ElementWise, *a, *b, nvinfer1::ElementWiseOperation::kMIN) ->getOutput(0); return c; } nvinfer1::ITensor* Max(nvinfer1::ITensor* a, nvinfer1::ITensor* b) { nvinfer1::ITensor* c = - TRT_ENGINE_ADD_LAYER(engine_, ElementWise, *a, *b, - nvinfer1::ElementWiseOperation::kMAX) + TRT_ENGINE_ADD_LAYER( + engine_, ElementWise, *a, *b, nvinfer1::ElementWiseOperation::kMAX) ->getOutput(0); return c; } nvinfer1::ITensor* Sub(nvinfer1::ITensor* a, nvinfer1::ITensor* b) { nvinfer1::ITensor* c = - TRT_ENGINE_ADD_LAYER(engine_, ElementWise, *a, *b, - nvinfer1::ElementWiseOperation::kSUB) + TRT_ENGINE_ADD_LAYER( + engine_, ElementWise, *a, *b, nvinfer1::ElementWiseOperation::kSUB) ->getOutput(0); return c; } nvinfer1::ITensor* Div(nvinfer1::ITensor* a, nvinfer1::ITensor* b) { nvinfer1::ITensor* c = - TRT_ENGINE_ADD_LAYER(engine_, ElementWise, *a, *b, - nvinfer1::ElementWiseOperation::kDIV) + TRT_ENGINE_ADD_LAYER( + engine_, ElementWise, *a, *b, nvinfer1::ElementWiseOperation::kDIV) ->getOutput(0); return c; } @@ -390,10 +411,14 @@ class OpConverter { // Get element tensor of 1D shape tensor nvinfer1::ITensor* GetEleTensorOfShape(nvinfer1::ITensor* shape_tensor, - int index, bool is_scalar = false) { + int index, + bool is_scalar = false) { auto* tensor = - TRT_ENGINE_ADD_LAYER(engine_, Gather, *shape_tensor, - *Add1DConstantLayer(index, " ", is_scalar), 0) + TRT_ENGINE_ADD_LAYER(engine_, + Gather, + *shape_tensor, + *Add1DConstantLayer(index, " ", is_scalar), + 0) ->getOutput(0); return tensor; } @@ -403,8 +428,8 @@ class OpConverter { const std::vector& weight_dims, const std::string& weight_name) { std::unique_ptr tmp_tensor(new framework::Tensor()); - int data_size = std::accumulate(weight_dims.begin(), weight_dims.end(), 1, - std::multiplies()); + int data_size = std::accumulate( + weight_dims.begin(), weight_dims.end(), 1, std::multiplies()); tmp_tensor->Resize({data_size}); auto* tmp_data = tmp_tensor->mutable_data(platform::CPUPlace()); for (int i = 0; i < data_size; i++) { @@ -489,7 +514,8 @@ class OpConverter { } void RreplenishLayerAndOutput( - nvinfer1::ILayer* layer, const std::string& layer_type, + nvinfer1::ILayer* layer, + const std::string& layer_type, const std::vector& output_tensor_names, bool test_mode = false) { size_t num_out = output_tensor_names.size(); diff --git a/paddle/fluid/inference/tensorrt/convert/slice_op.cc b/paddle/fluid/inference/tensorrt/convert/slice_op.cc index fa6f4889403654a03ddef0ce14db89481af70bef..bcf5e638126e260c912013a37c6002723ea6c1a6 100644 --- a/paddle/fluid/inference/tensorrt/convert/slice_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/slice_op.cc @@ -1,11 +1,8 @@ /* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -22,7 +19,8 @@ namespace tensorrt { class SliceOpConverter : public OpConverter { public: void operator()(const framework::proto::OpDesc& op, - const framework::Scope& scope, bool test_mode) override { + const framework::Scope& scope, + bool test_mode) override { // This OP is implemented by trt dynamic shpae plugin. // Dynamic shape plugin requires TRT version greater than 6.0. VLOG(4) << "convert slice op to tensorrt layer"; @@ -63,28 +61,118 @@ class SliceOpConverter : public OpConverter { } ends[i] = std::min(ends[i], input_dims.d[axes[i]]); PADDLE_ENFORCE_GT( - ends[i], starts[i], + ends[i], + starts[i], platform::errors::InvalidArgument( "Attr(ends) should be greater than attr(starts) in " "slice op. But received ends = %d, starts = %d.", - ends[i], starts[i])); + ends[i], + starts[i])); } } nvinfer1::ILayer* layer = nullptr; if (engine_->with_dynamic_shape()) { +#if IS_TRT_VERSION_GE(6000) + auto nchw_input_dims = input->getDimensions(); + nvinfer1::Dims trt_start_dims; + trt_start_dims.nbDims = nchw_input_dims.nbDims; + memset(trt_start_dims.d, 0, sizeof(int32_t) * nchw_input_dims.nbDims); + nvinfer1::Dims trt_size_dims = trt_start_dims; + nvinfer1::Dims trt_end_dims = trt_start_dims; + nvinfer1::Dims trt_step_dims = trt_start_dims; + for (int i = 0; i < trt_step_dims.nbDims; i++) trt_step_dims.d[i] = 1; + + // input : [N,C,H,W] + bool has_neg_indices = false; + for (size_t i = 0; i < axes.size(); i++) { + int trt_axis = axes[i]; + trt_start_dims.d[trt_axis] = starts[i]; + trt_end_dims.d[trt_axis] = ends[i]; + if (starts[i] < 0 || ends[i] < 0) has_neg_indices = true; + } + auto* shape_tensor = Shape(input); + auto* start_tensor = Add1DConstantLayer(trt_start_dims); + if (has_neg_indices) { + start_tensor = FixNegIndices(shape_tensor, start_tensor); + } + + std::vector end_vec_tensor; + for (int i = 0; i < trt_end_dims.nbDims; i++) { + end_vec_tensor.push_back(GetEleTensorOfShape(shape_tensor, i)); + } + + for (size_t i = 0; i < axes.size(); i++) { + int trt_axis = axes[i]; + if (ends[i] >= 0) { + end_vec_tensor[trt_axis] = Add1DConstantLayer(ends[i]); + } else { + end_vec_tensor[trt_axis] = + Sum(end_vec_tensor[trt_axis], Add1DConstantLayer(ends[i])); + } + } + +// CI failed in trt 6015 but success in 7134, may be a trt bug +#if IS_TRT_VERSION_GE(7134) + auto* size_tensor = + Sub(Min(Concat(end_vec_tensor), shape_tensor), start_tensor); +#else + auto* size_tensor = Sub(Concat(end_vec_tensor), start_tensor); +#endif + + layer = TRT_ENGINE_ADD_LAYER( + engine_, Slice, *input, trt_start_dims, trt_size_dims, trt_step_dims); + layer->setInput(1, *start_tensor); + layer->setInput(2, *size_tensor); + + if (decrease_axises.size() > 0) { + std::vector gather_indices; + for (int i = 0; i < trt_size_dims.nbDims; i++) { + if (decrease_axises.end() != + std::find(decrease_axises.begin(), decrease_axises.end(), i)) + continue; + gather_indices.push_back(i); + } + if (gather_indices.empty()) + gather_indices.push_back(decrease_axises[0]); + auto real_size_tensor = Gather(size_tensor, gather_indices); + layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *layer->getOutput(0)); + layer->setInput(1, *real_size_tensor); + } +#else bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); int decrease_axis = decrease_axises.size() == 0 ? -1 : decrease_axises[0]; plugin::SlicePluginDynamic* plugin = new plugin::SlicePluginDynamic( starts, ends, axes, decrease_axis, with_fp16); layer = engine_->AddDynamicPlugin(&input, 1, plugin); +#endif } else { +#if IS_TRT_VERSION_GE(6000) + auto chw_input_dims = input->getDimensions(); + nvinfer1::Dims trt_start_dims; + trt_start_dims.nbDims = chw_input_dims.nbDims; + memset(trt_start_dims.d, 0, sizeof(int32_t) * chw_input_dims.nbDims); + nvinfer1::Dims trt_size_dims = chw_input_dims; + nvinfer1::Dims trt_step_dims; + trt_step_dims.nbDims = chw_input_dims.nbDims; + for (int i = 0; i < trt_step_dims.nbDims; i++) trt_step_dims.d[i] = 1; + + // input : [C,H,W] + for (size_t i = 0; i < axes.size(); i++) { + int trt_axis = axes[i] - 1; + trt_start_dims.d[trt_axis] = starts[i]; + trt_size_dims.d[trt_axis] = ends[i] - starts[i]; + } + layer = TRT_ENGINE_ADD_LAYER( + engine_, Slice, *input, trt_start_dims, trt_size_dims, trt_step_dims); +#else bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); plugin::SlicePlugin* plugin = new plugin::SlicePlugin(starts, ends, axes, with_fp16); layer = engine_->AddPlugin(&input, 1, plugin); +#endif } RreplenishLayerAndOutput(layer, "slice", {output_name}, test_mode); } diff --git a/paddle/fluid/inference/tensorrt/engine.cc b/paddle/fluid/inference/tensorrt/engine.cc index 7f308fd3a04d5e33f93ea15380655f712763548a..82c51311a03d5fbffdc563d24d91e69f9643fa5f 100644 --- a/paddle/fluid/inference/tensorrt/engine.cc +++ b/paddle/fluid/inference/tensorrt/engine.cc @@ -49,7 +49,8 @@ void TensorRTEngine::InitNetwork() { optim_profiles_[i] = infer_builder_->createOptimizationProfile(); } -void TensorRTEngine::Execute(int batch_size, std::vector *buffers, +void TensorRTEngine::Execute(int batch_size, + std::vector *buffers, cudaStream_t stream) { freshDeviceId(); auto infer_context = context(); @@ -129,14 +130,32 @@ void TensorRTEngine::FreezeNetwork() { } #if IS_TRT_VERSION_GE(5122) - auto is_layer_int8 = [&](nvinfer1::ILayer *layer) -> bool { + auto layer_int8_fallback = [&](nvinfer1::ILayer *layer) -> bool { + if (layer->getType() == nvinfer1::LayerType::kSHAPE) { + return false; + } + bool all_int = true; + for (int j = 0; j < layer->getNbInputs(); j++) { + auto *temp_in = layer->getInput(j); + if (temp_in->getType() != nvinfer1::DataType::kINT32) { + all_int = false; + } + } + for (int j = 0; j < layer->getNbOutputs(); j++) { + auto *temp_out = layer->getOutput(j); + if (temp_out->getType() != nvinfer1::DataType::kINT32) { + all_int = false; + } + } + if (all_int) return false; + for (int j = 0; j < layer->getNbInputs(); j++) { auto *temp_in = layer->getInput(j); if (!temp_in->dynamicRangeIsSet()) { VLOG(1) << "Layer(Name: " << layer->getName() << ") is set to float32 because its input(" << temp_in->getName() << ") doesn't have dynamic range."; - return false; + return true; } } for (int j = 0; j < layer->getNbOutputs(); j++) { @@ -145,10 +164,10 @@ void TensorRTEngine::FreezeNetwork() { VLOG(1) << "Layer(Name: " << layer->getName() << ") is set to float32 because its output(" << temp_out->getName() << ") doesn't have dynamic range."; - return false; + return true; } } - return true; + return false; }; // If a layer's output is the network's output, or not all of its inputs // and outputs have scales, @@ -157,7 +176,7 @@ void TensorRTEngine::FreezeNetwork() { int layers_no_int8 = 0; for (int i = 0; i < network()->getNbLayers(); i++) { auto layer = network()->getLayer(i); - if (!is_layer_int8(layer)) { + if (layer_int8_fallback(layer)) { layer->setPrecision(nvinfer1::DataType::kFLOAT); ++layers_no_int8; } @@ -208,7 +227,8 @@ void TensorRTEngine::FreezeNetwork() { for (auto &input : min_input_shape_) { #if IS_TRT_VERSION_LT(7000) // trt6 will check all_of input > 0 - if (!(std::all_of(input.second.begin(), input.second.end(), + if (!(std::all_of(input.second.begin(), + input.second.end(), [](int x) { return x > 0; }) && std::all_of(max_input_shape_[input.first].begin(), max_input_shape_[input.first].end(), @@ -225,13 +245,16 @@ void TensorRTEngine::FreezeNetwork() { << ", opt: " << Vec2Str(optim_input_shape_[input.first]); optim_profiles_[i]->setDimensions( - input.first.c_str(), nvinfer1::OptProfileSelector::kMIN, + input.first.c_str(), + nvinfer1::OptProfileSelector::kMIN, Vec2TRT_Dims(input.second, input.first, true)); optim_profiles_[i]->setDimensions( - input.first.c_str(), nvinfer1::OptProfileSelector::kMAX, + input.first.c_str(), + nvinfer1::OptProfileSelector::kMAX, Vec2TRT_Dims(max_input_shape_[input.first], input.first, true)); optim_profiles_[i]->setDimensions( - input.first.c_str(), nvinfer1::OptProfileSelector::kOPT, + input.first.c_str(), + nvinfer1::OptProfileSelector::kOPT, Vec2TRT_Dims(optim_input_shape_[input.first], input.first, true)); } infer_builder_config_->addOptimizationProfile(optim_profiles_[i]); @@ -265,9 +288,10 @@ void TensorRTEngine::FreezeNetwork() { #endif PADDLE_ENFORCE_NOT_NULL( - infer_engine_, platform::errors::Fatal( - "Build TensorRT cuda engine failed! Please recheck " - "you configurations related to paddle-TensorRT.")); + infer_engine_, + platform::errors::Fatal( + "Build TensorRT cuda engine failed! Please recheck " + "you configurations related to paddle-TensorRT.")); binding_num_ = infer_engine_->getNbBindings(); // reset status for dynamic shape clone @@ -282,16 +306,19 @@ void TensorRTEngine::FreezeNetwork() { nvinfer1::ITensor *TensorRTEngine::DeclareInput(const std::string &name, nvinfer1::DataType dtype, const nvinfer1::Dims &dims) { - PADDLE_ENFORCE_EQ(network() != nullptr, true, + PADDLE_ENFORCE_EQ(network() != nullptr, + true, platform::errors::InvalidArgument( "The TRT network should be initialized first.")); auto *input = network()->addInput(name.c_str(), dtype, dims); PADDLE_ENFORCE_NOT_NULL( - input, platform::errors::InvalidArgument("Adding input %s failed in " - "TensorRT inference network. " - "Please recheck your input.", - name)); - PADDLE_ENFORCE_EQ(input->isNetworkInput(), true, + input, + platform::errors::InvalidArgument("Adding input %s failed in " + "TensorRT inference network. " + "Please recheck your input.", + name)); + PADDLE_ENFORCE_EQ(input->isNetworkInput(), + true, platform::errors::InvalidArgument( "Input %s is not the input of TRT inference network. " "Please recheck your input.", @@ -300,22 +327,26 @@ nvinfer1::ITensor *TensorRTEngine::DeclareInput(const std::string &name, return input; } -void TensorRTEngine::DeclareOutput(const nvinfer1::ILayer *layer, int offset, +void TensorRTEngine::DeclareOutput(const nvinfer1::ILayer *layer, + int offset, const std::string &name) { auto *output = layer->getOutput(offset); SetITensor(name, output); PADDLE_ENFORCE_NOT_NULL( - output, platform::errors::InvalidArgument( - "The output %s of TRT engine should not be null.", name)); + output, + platform::errors::InvalidArgument( + "The output %s of TRT engine should not be null.", name)); output->setName(name.c_str()); - PADDLE_ENFORCE_EQ(output->isNetworkInput(), false, + PADDLE_ENFORCE_EQ(output->isNetworkInput(), + false, platform::errors::InvalidArgument( "The output %s of TRT engine should not be the input " "of the network at the same time.", name)); network()->markOutput(*output); PADDLE_ENFORCE_EQ( - output->isNetworkOutput(), true, + output->isNetworkOutput(), + true, platform::errors::InvalidArgument( "The output %s of TRT engine should be the output of the network.", name)); @@ -324,10 +355,12 @@ void TensorRTEngine::DeclareOutput(const nvinfer1::ILayer *layer, int offset, void TensorRTEngine::DeclareOutput(const std::string &name) { auto *output = TensorRTEngine::GetITensor(name); PADDLE_ENFORCE_NOT_NULL( - output, platform::errors::InvalidArgument( - "The output %s of TRT engine should not be null.", name)); + output, + platform::errors::InvalidArgument( + "The output %s of TRT engine should not be null.", name)); output->setName(name.c_str()); - PADDLE_ENFORCE_EQ(output->isNetworkInput(), false, + PADDLE_ENFORCE_EQ(output->isNetworkInput(), + false, platform::errors::InvalidArgument( "The output %s of TRT engine should not be the input " "of the network at the same time.", @@ -338,17 +371,20 @@ void TensorRTEngine::DeclareOutput(const std::string &name) { void TensorRTEngine::SetITensor(const std::string &name, nvinfer1::ITensor *tensor) { PADDLE_ENFORCE_NOT_NULL( - tensor, platform::errors::InvalidArgument( - "Tensor named %s of TRT engine should not be null.", name)); + tensor, + platform::errors::InvalidArgument( + "Tensor named %s of TRT engine should not be null.", name)); PADDLE_ENFORCE_EQ( - 0, itensor_map_.count(name), + 0, + itensor_map_.count(name), platform::errors::InvalidArgument( "Tensor named %s of TRT engine should not be duplicated", name)); itensor_map_[name] = tensor; } nvinfer1::ITensor *TensorRTEngine::GetITensor(const std::string &name) { - PADDLE_ENFORCE_EQ(itensor_map_.count(name), true, + PADDLE_ENFORCE_EQ(itensor_map_.count(name), + true, platform::errors::NotFound( "Tensor named %s is not found in TRT engine", name)); return itensor_map_[name]; @@ -365,15 +401,16 @@ float *TensorRTEngine::GetWeightCPUData(const std::string &name, std::string splitter = "__"; std::string name_with_suffix = name + splitter + name_suffix; platform::CPUPlace cpu_place; - PADDLE_ENFORCE_EQ(weight_map.count(name_with_suffix), 0, + PADDLE_ENFORCE_EQ(weight_map.count(name_with_suffix), + 0, platform::errors::AlreadyExists( "The weight named %s is set into the weight map " "twice in TRT OP converter.", name_with_suffix)); weight_map[name_with_suffix].reset(new framework::Tensor()); weight_map[name_with_suffix]->Resize(weight_tensor->dims()); - paddle::framework::TensorCopySync(*weight_tensor, cpu_place, - weight_map[name_with_suffix].get()); + paddle::framework::TensorCopySync( + *weight_tensor, cpu_place, weight_map[name_with_suffix].get()); float *weight_data = weight_map[name_with_suffix]->mutable_data(cpu_place); name_suffix_counter += 1; @@ -383,21 +420,24 @@ float *TensorRTEngine::GetWeightCPUData(const std::string &name, int TensorRTEngine::GetRuntimeBatch() { return runtime_batch_; } nvinfer1::IPluginV2Layer *TensorRTEngine::AddPlugin( - nvinfer1::ITensor *const *inputs, int num_inputs, + nvinfer1::ITensor *const *inputs, + int num_inputs, plugin::PluginTensorRT *plugin) { owned_plugin_.emplace_back(plugin); return network()->addPluginV2(inputs, num_inputs, *plugin); } nvinfer1::IPluginV2Layer *TensorRTEngine::AddPluginV2Ext( - nvinfer1::ITensor *const *inputs, int num_inputs, + nvinfer1::ITensor *const *inputs, + int num_inputs, plugin::PluginTensorRTV2Ext *plugin) { owned_plugin_v2ext_.emplace_back(plugin); return network()->addPluginV2(inputs, num_inputs, *plugin); } nvinfer1::IPluginV2Layer *TensorRTEngine::AddPluginV2IOExt( - nvinfer1::ITensor *const *inputs, int num_inputs, + nvinfer1::ITensor *const *inputs, + int num_inputs, nvinfer1::IPluginV2IOExt *plugin) { owned_plugin_v2ioext_.emplace_back(plugin); return network()->addPluginV2(inputs, num_inputs, *plugin); @@ -406,10 +446,12 @@ nvinfer1::IPluginV2Layer *TensorRTEngine::AddPluginV2IOExt( void TensorRTEngine::freshDeviceId() { int count; cudaGetDeviceCount(&count); - PADDLE_ENFORCE_LT(device_id_, count, + PADDLE_ENFORCE_LT(device_id_, + count, platform::errors::OutOfRange( "Device id %d exceeds the current device count: %d.", - device_id_, count)); + device_id_, + count)); platform::SetDeviceId(device_id_); } diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_slice.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_slice.py index 003c84c4c5ab069a6ee47e09d495ca3dbb4fc74d..76a84c77122c5d6c59140b59c6c15c42dda4a64d 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_slice.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_slice.py @@ -62,7 +62,7 @@ class TrtConvertSliceTest(TrtLayerAutoScanTest): for axes in [[0, 1], [1, 3], [2, 3]]: for starts in [[0, 1]]: - for ends in [[2, 2], [5, 5]]: + for ends in [[2, 2], [5, 5], [1, -1]]: for decrease_axis in [[], [1], [2], [-1], [-100]]: for infer_flags in [[-1]]: dics = [{ @@ -118,10 +118,6 @@ class TrtConvertSliceTest(TrtLayerAutoScanTest): return 0, 3 if dynamic_shape == False and len(attrs[0]["decrease_axis"]) != 0: return 0, 3 - if dynamic_shape: - for i in range(len(attrs[0]["starts"])): - if attrs[0]["starts"][i] < 0 or attrs[0]["ends"][i] < 0: - return 0, 3 if not dynamic_shape: for x in attrs[0]["axes"]: if x == 0: