diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc index bbacad7b8c25f58ac2fe3683a4c805c3715f565e..3f35d225dc2f648169d9605369664aeffbbc6b4f 100644 --- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc +++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc @@ -147,11 +147,11 @@ void TensorRtSubgraphPass::CreateTensorRTOp( std::set output_names; std::set output_names_with_id; - std::vector origin_output_dims; + std::map origin_name_output_dims; for (auto *x : node->outputs) { output_names.insert(x->Name()); output_names_with_id.insert(x->Name() + std::to_string(x->id())); - origin_output_dims.push_back(x->Var()->GetShape().size()); + origin_name_output_dims[x->Name()] = x->Var()->GetShape().size(); } std::unordered_map output_name_map; @@ -195,9 +195,11 @@ void TensorRtSubgraphPass::CreateTensorRTOp( // output_mapping help us copy the data from the renamed ITensor // to Tensor. std::vector output_mapping; + std::vector renamed_output_dims; for (auto name : output_names) { PADDLE_ENFORCE(output_name_map.count(name) != 0); output_mapping.push_back(output_name_map[name]); + renamed_output_dims.push_back(origin_name_output_dims[name]); } PADDLE_ENFORCE(!output_mapping.empty()); PADDLE_ENFORCE(!block_desc.Proto()->vars().empty(), @@ -217,7 +219,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp( op_desc->SetAttr("workspace_size", Get("workspace_size")); op_desc->SetAttr("gpu_id", Get("gpu_device_id")); op_desc->SetAttr("output_name_mapping", output_mapping); - op_desc->SetAttr("origin_output_dims", origin_output_dims); + op_desc->SetAttr("origin_output_dims", renamed_output_dims); op_desc->SetAttr("parameters", params); // we record all inputs' shapes in attr to check if they are consistent diff --git a/paddle/fluid/inference/tensorrt/convert/fc_op.cc b/paddle/fluid/inference/tensorrt/convert/fc_op.cc index 8a2da41afc8c57030cbc8113a98221f8a03b69d9..fa4c6b5ea9114f565f2fd1ae6ad12d0771d2c085 100644 --- a/paddle/fluid/inference/tensorrt/convert/fc_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/fc_op.cc @@ -133,7 +133,69 @@ class FcOpConverter : public OpConverter { static_cast(bias_num)}; if (engine_->with_dynamic_shape()) { - regist_fc(X, n_output, weight, bias); + // not NCHW layout, but NLP layout with added 'x 1 x 1' + auto x_dim = X->getDimensions(); + if (x_dim.nbDims == 3 || x_dim.nbDims == 2) { + auto output_name = op_desc.Output("Out").front(); + // add shuffle before fc + nvinfer1::Dims reshape_before_fc_dim; + reshape_before_fc_dim.nbDims = x_dim.nbDims + 2; + for (int i = 0; i < x_dim.nbDims; i++) { + reshape_before_fc_dim.d[i] = 0; + } + reshape_before_fc_dim.d[x_dim.nbDims] = 1; + reshape_before_fc_dim.d[x_dim.nbDims + 1] = 1; + auto* reshape_before_fc_layer = + TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *X); + reshape_before_fc_layer->setReshapeDimensions(reshape_before_fc_dim); + reshape_before_fc_layer->setName( + ("shuffle_before_fc(Output: " + output_name + ")").c_str()); + + // add fc layer + auto* fc_layer = TRT_ENGINE_ADD_LAYER( + engine_, FullyConnected, *reshape_before_fc_layer->getOutput(0), + n_output, weight.get(), bias.get()); + fc_layer->setName(("fc_layer(Output: " + output_name + ")").c_str()); + + // add shuffle after fc + nvinfer1::Dims reshape_after_fc_dim; + if (x_dim.nbDims == 3) { + if (x_num_col_dims == 2) { + reshape_after_fc_dim.nbDims = 3; + reshape_after_fc_dim.d[0] = 0; + reshape_after_fc_dim.d[1] = 0; + reshape_after_fc_dim.d[2] = 0; + } else { + reshape_after_fc_dim.nbDims = 2; + reshape_after_fc_dim.d[0] = 0; + auto dim = fc_layer->getOutput(0)->getDimensions(); + reshape_after_fc_dim.d[1] = dim.d[1] * dim.d[2]; + } + // x_dim.nbDims == 2 + } else { + reshape_after_fc_dim.nbDims = 2; + reshape_after_fc_dim.d[0] = 0; + reshape_after_fc_dim.d[1] = 0; + } + auto* reshape_after_fc_layer = + TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *fc_layer->getOutput(0)); + reshape_after_fc_layer->setReshapeDimensions(reshape_after_fc_dim); + + if (activation_type == "relu") { + reshape_after_fc_layer->setName( + ("shuffle_after_fc(Output: " + output_name + ")").c_str()); + nvinfer1::IActivationLayer* relu_layer = TRT_ENGINE_ADD_LAYER( + engine_, Activation, *(reshape_after_fc_layer->getOutput(0)), + nvinfer1::ActivationType::kRELU); + RreplenishLayerAndOutput(relu_layer, "relu_after_fc_shuffle", + {output_name}, test_mode); + } else { + RreplenishLayerAndOutput(reshape_after_fc_layer, "shuffle_after_fc", + {output_name}, test_mode); + } + } else { + regist_fc(X, n_output, weight, bias); + } return; } // in order to handle situations in NLP models(input dims < 3, @@ -143,12 +205,6 @@ class FcOpConverter : public OpConverter { auto input_d = X->getDimensions().d; int reshape_dim3[3] = {0}; int reshape_dim4[4] = {0}; - PADDLE_ENFORCE_EQ( - x_num_col_dims == 1 || x_num_col_dims == 2, true, - platform::errors::InvalidArgument( - "Wrong x_num_col_dims param of op mul. Paddle-TRT FC converter " - "expects x_num_col_dims is either 1 or 2, but got %d", - x_num_col_dims)); PADDLE_ENFORCE_LE(x_num_col_dims, input_dims, platform::errors::InvalidArgument( "Params and input dims mismatch. Paddle-TRT FC " diff --git a/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc b/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc index 82c01490d9e1cfe2405ef5e6fe2d2dbd7850d201..b016cf418bc379c0edb1c4ceac570286cf45c6c2 100644 --- a/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc @@ -8,8 +8,8 @@ http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See +the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/inference/tensorrt/convert/op_converter.h" @@ -28,7 +28,6 @@ class MultiheadMatMulOpConverter : public OpConverter { "network structure"; framework::OpDesc op_desc(op, nullptr); // Declare inputs - // Shouble be a 5 dims tensor. auto* input = engine_->GetITensor(op_desc.Input("Input").front()); // fc weights and fc bias @@ -69,6 +68,7 @@ class MultiheadMatMulOpConverter : public OpConverter { int head_number = boost::get(op_desc.GetAttr("head_number")); nvinfer1::ILayer* layer = nullptr; + auto output_name = op_desc.Output("Out")[0]; if (engine_->with_dynamic_shape()) { if (engine_->use_oss()) { @@ -170,6 +170,12 @@ class MultiheadMatMulOpConverter : public OpConverter { plugin_inputs.data(), plugin_inputs.size(), *plugin); layer = plugin_layer; } else { + PADDLE_ENFORCE_EQ( + input->getDimensions().nbDims, 3, + platform::errors::InvalidArgument( + "The Input dim of the MultiheadMatMul should be 3, " + "but it's (%d) now.", + input->getDimensions().nbDims)); // transpose weight_data from m * n to n * m auto* input_bias_qk = engine_->GetITensor(op_desc.Input("BiasQK").front()); @@ -183,15 +189,37 @@ class MultiheadMatMulOpConverter : public OpConverter { static_cast(bias_data), static_cast(bias_t->numel())}; - auto* fc_layer = TRT_ENGINE_ADD_LAYER(engine_, FullyConnected, *input, - n, weight.get(), bias.get()); - auto* fc_out = fc_layer->getOutput(0); + // add shuffle before fc + nvinfer1::Dims reshape_before_fc_dim; + reshape_before_fc_dim.nbDims = 5; + reshape_before_fc_dim.d[0] = 0; + reshape_before_fc_dim.d[1] = 0; + reshape_before_fc_dim.d[2] = 0; + reshape_before_fc_dim.d[3] = 1; + reshape_before_fc_dim.d[4] = 1; + auto* reshape_before_fc_layer = + TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input); + reshape_before_fc_layer->setReshapeDimensions(reshape_before_fc_dim); + reshape_before_fc_layer->setName( + ("shuffle_before_multihead_mamul(Output: " + output_name + ")") + .c_str()); + + // add layer fc + auto* fc_layer = TRT_ENGINE_ADD_LAYER( + engine_, FullyConnected, *reshape_before_fc_layer->getOutput(0), n, + weight.get(), bias.get()); + fc_layer->setName( + ("multihead_mamul_fc(Output: " + output_name + ")").c_str()); + + // no need to add shuffle after fc, just change it in + // QkvToContextPluginDynamic + // add qkv to context int head_size = hidden_out / head_number; float scale = boost::get(op_desc.GetAttr("alpha")); std::vector plugin_inputs; - plugin_inputs.push_back(fc_out); + plugin_inputs.push_back(fc_layer->getOutput(0)); plugin_inputs.push_back(input_bias_qk); bool with_fp16 = engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); @@ -207,7 +235,6 @@ class MultiheadMatMulOpConverter : public OpConverter { "You can use the config.SetTRTDynamicShapeInfo(...) interface to set " "the shape information to run the dynamic shape mode.")); } - auto output_name = op_desc.Output("Out")[0]; RreplenishLayerAndOutput(layer, "multihead_matmul", {output_name}, test_mode); #else diff --git a/paddle/fluid/inference/tensorrt/convert/scale_op.cc b/paddle/fluid/inference/tensorrt/convert/scale_op.cc index 10f238efac385fe7dd7d2d47842f2ccf5e1a73a8..3b7da39c94101989bd380d993afb46f15bfd7c2f 100644 --- a/paddle/fluid/inference/tensorrt/convert/scale_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/scale_op.cc @@ -48,6 +48,8 @@ class ScaleOpConverter : public OpConverter { return tmp_data; }; + int dynamic_shape_offset = engine_->with_dynamic_shape() ? 1 : 0; + float* bias_ptr = create_weights(bias, "bias"); float* scale_ptr = create_weights(scale, "scale"); @@ -60,19 +62,22 @@ class ScaleOpConverter : public OpConverter { nvinfer1::ILayer* layer = nullptr; auto input_dim = input->getDimensions(); - PADDLE_ENFORCE_GE(input_dim.nbDims, 3, - platform::errors::Fatal( - "Paddle-TRT scale mode only support dimension >= 3")); nvinfer1::IShuffleLayer* expand_layer = nullptr; nvinfer1::IShuffleLayer* squeeze_layer = nullptr; - if (input_dim.nbDims == 3) { - // TensorRT scale layer is not supporting input dims < 4 when using - // explicit batch + if (input_dim.nbDims < 3 + dynamic_shape_offset) { + nvinfer1::Dims expand_shape; + expand_shape.nbDims = 3 + dynamic_shape_offset; + for (int i = 0; i < 3 + dynamic_shape_offset; i++) { + if (i < input_dim.nbDims) { + expand_shape.d[i] = input_dim.d[i] < 0 ? 0 : input_dim.d[i]; + } else { + expand_shape.d[i] = 1; + } + } expand_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input); - nvinfer1::Dims4 target_shape(0, 0, 0, 1); // expand 1 dims - expand_layer->setReshapeDimensions(target_shape); + expand_layer->setReshapeDimensions(expand_shape); input = expand_layer->getOutput(0); } @@ -94,13 +99,15 @@ class ScaleOpConverter : public OpConverter { PADDLE_ENFORCE_EQ(layer != nullptr, true, platform::errors::Fatal("Create scale layer failed.")); - if (input_dim.nbDims == 3) { - // TensorRT scale layer is not supporting input dims < 4 when using - // explicit batch + if (input_dim.nbDims < 3 + dynamic_shape_offset) { + nvinfer1::Dims squeeze_shape; + squeeze_shape.nbDims = input_dim.nbDims; + for (int i = 0; i < squeeze_shape.nbDims; i++) { + squeeze_shape.d[i] = input_dim.d[i] < 0 ? 0 : input_dim.d[i]; + } squeeze_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *(layer->getOutput(0))); - nvinfer1::Dims3 target_shape(0, 0, 0); // expand 1 dims - squeeze_layer->setReshapeDimensions(target_shape); + squeeze_layer->setReshapeDimensions(squeeze_shape); layer = static_cast(squeeze_layer); } RreplenishLayerAndOutput(layer, "scale", {out_name}, test_mode); diff --git a/paddle/fluid/inference/tensorrt/convert/softmax_op.cc b/paddle/fluid/inference/tensorrt/convert/softmax_op.cc index 05c9c0ec5da9a80a0afee1780daa23f78dde4e9d..6b663071b49cbe3f27ba2b8e20d34a65937cf9f1 100644 --- a/paddle/fluid/inference/tensorrt/convert/softmax_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/softmax_op.cc @@ -50,6 +50,7 @@ class SoftMaxOpConverter : public OpConverter { uint32_t axes = std::max(0, input_dims - 3); // TODO(cryoco): Poor workaround. Fix padded dims problem when TRT layers // support Nd. + // Tips: Dynammic shape alreay fixes. int padded_dims = 0; int explicit_batch = 0; if (engine_->with_dynamic_shape()) explicit_batch = 1; @@ -61,16 +62,16 @@ class SoftMaxOpConverter : public OpConverter { } } if (!engine_->with_dynamic_shape()) { - if (axis == -1) { - axes = input_dims - 1 - padded_dims; + if (axis < 0) { + axes = input_dims + axis - padded_dims; } else { - axes = axis; + axes = axis - 1; } } else { - if (axis == -1) { - axes = input_dims - 1 - padded_dims; + if (axis < 0) { + axes = input_dims + axis; } else { - axes = axis + 1; + axes = axis; } } layer->setAxes(1 << axes); diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 34d6881560a4a28719106b856b372f8afbed33ed..c34749e6d32c27e6233b9abb4e45bbb08c32b390 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -131,6 +131,18 @@ bool OpTeller::Tell(const std::string& op_type, const framework::OpDesc& desc, } } } + if (op_type == "fc" || op_type == "mul") { + const int x_num_col_dims = + desc.HasAttr("x_num_col_dims") + ? boost::get(desc.GetAttr("x_num_col_dims")) + : (desc.HasAttr("in_num_col_dims") + ? boost::get(desc.GetAttr("in_num_col_dims")) + : 1); + if (x_num_col_dims != 1 && x_num_col_dims != 2) { + return false; + } + } + if ((*teller)(op_type, desc, use_no_calib_int8)) return true; } return false; diff --git a/paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.cu index 0ec803fe64afadd970777e3b0d0ab5d37fcc4d22..9cb703cfe0fd22baf7b75a6b2a05e8fac299906a 100644 --- a/paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.cu @@ -127,9 +127,14 @@ int ElementWisePlugin::enqueue(int batch_size, const void *const *inputs, int ElementwisePluginDynamic::initialize() { return 0; } -size_t ElementwisePluginDynamic::getSerializationSize() const { return 0; } +size_t ElementwisePluginDynamic::getSerializationSize() const { + return SerializedSize(type_.c_str()) + SerializedSize(axis_); +} -void ElementwisePluginDynamic::serialize(void *buffer) const {} +void ElementwisePluginDynamic::serialize(void *buffer) const { + SerializeValue(&buffer, type_.c_str()); + SerializeValue(&buffer, axis_); +} nvinfer1::DimsExprs ElementwisePluginDynamic::getOutputDimensions( int output_index, const nvinfer1::DimsExprs *inputs, int nb_inputs, diff --git a/paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.h b/paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.h index e37511868d88f600a733df4ebb478e74a385be1b..49212aae9aa90dace4b2824cb9b6f0a7b6127f31 100644 --- a/paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.h +++ b/paddle/fluid/inference/tensorrt/plugin/elementwise_op_plugin.h @@ -92,7 +92,12 @@ class ElementwisePluginDynamic : public DynamicPluginTensorRT { public: explicit ElementwisePluginDynamic(const std::string& type, int axis) : type_(type), axis_(axis) {} - ElementwisePluginDynamic(void const* serialData, size_t serialLength) {} + ElementwisePluginDynamic(void const* serialData, size_t serialLength) { + const char* elementwise_type; + DeserializeValue(&serialData, &serialLength, &elementwise_type); + type_ = std::string(elementwise_type); + DeserializeValue(&serialData, &serialLength, &axis_); + } nvinfer1::IPluginV2DynamicExt* clone() const override { return new ElementwisePluginDynamic(type_, axis_); } @@ -138,6 +143,46 @@ class ElementwisePluginDynamic : public DynamicPluginTensorRT { std::string type_; int axis_; }; + +class ElementwisePluginV2Creator : public nvinfer1::IPluginCreator { + public: + ElementwisePluginV2Creator() {} + const char* getPluginName() const override { return "elementwise_plugin"; } + + const char* getPluginVersion() const override { return "1"; } + + const nvinfer1::PluginFieldCollection* getFieldNames() override { + return &field_collection_; + } + + nvinfer1::IPluginV2* createPlugin( + const char* name, const nvinfer1::PluginFieldCollection* fc) override { + return nullptr; + } + + nvinfer1::IPluginV2* deserializePlugin(const char* name, + const void* serial_data, + size_t serial_length) override { + auto plugin = new ElementwisePluginDynamic(serial_data, serial_length); + return plugin; + } + + void setPluginNamespace(const char* lib_namespace) override { + plugin_namespace_ = lib_namespace; + } + + const char* getPluginNamespace() const override { + return plugin_namespace_.c_str(); + } + + private: + std::string plugin_namespace_; + std::string plugin_name_; + nvinfer1::PluginFieldCollection field_collection_{0, nullptr}; + std::vector plugin_attributes_; +}; + +REGISTER_TRT_PLUGIN_V2(ElementwisePluginV2Creator); #endif } // namespace plugin diff --git a/paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.cu index 30667514ac83a466fb7c131e66286617a62a778e..81c70d067d7e50949ec33b19c3456046d7e62f14 100644 --- a/paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/emb_eltwise_layernorm_plugin.cu @@ -182,12 +182,10 @@ nvinfer1::DimsExprs EmbEltwiseLayernormPluginDynamic::getOutputDimensions( "but it's (%d)", output_index)); nvinfer1::DimsExprs ret; - ret.nbDims = 5; + ret.nbDims = 3; ret.d[0] = inputs[0].d[0]; ret.d[1] = inputs[0].d[1]; ret.d[2] = expr_builder.constant(hidden_size_); - ret.d[3] = expr_builder.constant(1); - ret.d[4] = expr_builder.constant(1); return ret; } diff --git a/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu index f13d2fedac166b576f96a5a0ae0add4851b7739a..2f7fc3bce4cb469bbcdf4d83c1b822e37c970635 100644 --- a/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu @@ -158,12 +158,10 @@ nvinfer1::DimsExprs QkvToContextPluginDynamic::getOutputDimensions( "it has (%d) inputs", nb_inputs)); nvinfer1::DimsExprs ret; - ret.nbDims = 5; + ret.nbDims = 3; ret.d[0] = inputs[0].d[0]; ret.d[1] = inputs[0].d[1]; ret.d[2] = expr_builder.constant(head_size_ * head_number_); - ret.d[3] = expr_builder.constant(1); - ret.d[4] = expr_builder.constant(1); return ret; } diff --git a/paddle/fluid/inference/tensorrt/plugin/skip_layernorm_op_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/skip_layernorm_op_plugin.cu index 6b2b93ba2230faa3355075252a8e94db65f8df28..b3d4e16d340a0b0b284495db5e3a1858af7a8439 100644 --- a/paddle/fluid/inference/tensorrt/plugin/skip_layernorm_op_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/skip_layernorm_op_plugin.cu @@ -43,11 +43,6 @@ int SkipLayerNormPluginDynamic::initialize() { nvinfer1::DimsExprs SkipLayerNormPluginDynamic::getOutputDimensions( int output_index, const nvinfer1::DimsExprs *inputs, int nb_inputs, nvinfer1::IExprBuilder &expr_builder) { - PADDLE_ENFORCE_EQ( - inputs[0].nbDims, 5, - platform::errors::InvalidArgument( - "The Input dim of the SkipLayernorm should be 5, but it's (%d) now.", - inputs[0].nbDims)); return inputs[0]; } diff --git a/paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.cu b/paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.cu index 250b944652b93c54ae9587271256b42c6e1bc6b7..fdb14f9ceaf29fe90cd756b77e7c5afff2296f44 100644 --- a/paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.cu +++ b/paddle/fluid/inference/tensorrt/plugin/special_slice_plugin.cu @@ -62,6 +62,8 @@ nvinfer1::DimsExprs SpecialSlicePluginDynamic::getOutputDimensions( output.d[1] = one; output.d[0] = expr_builder.operation(nvinfer1::DimensionOperation::kSUB, *inputs[1].d[0], *one); + // remove padding 1 + output.nbDims -= 2; return output; } diff --git a/paddle/fluid/inference/tests/api/trt_dynamic_shape_transformer_prune_test.cc b/paddle/fluid/inference/tests/api/trt_dynamic_shape_transformer_prune_test.cc index 3916cf361c4b87602d9abc996788566da0488bbf..75e085efb69f5d56c2ce6bd1dd3cd7923e7394d5 100644 --- a/paddle/fluid/inference/tests/api/trt_dynamic_shape_transformer_prune_test.cc +++ b/paddle/fluid/inference/tests/api/trt_dynamic_shape_transformer_prune_test.cc @@ -126,7 +126,7 @@ void trt_ernie(bool with_fp16, std::vector result) { run(config, &out_data); for (size_t i = 0; i < out_data.size(); i++) { - EXPECT_NEAR(result[i], out_data[i], 1e-4); + EXPECT_NEAR(result[i], out_data[i], 3e-3); } } diff --git a/paddle/fluid/operators/softmax_cudnn_op.cu.cc b/paddle/fluid/operators/softmax_cudnn_op.cu.cc index ad3e5543f10ae05865565110ba2231c897c205b8..94e54266f0f922efef5ea4a1b23338b6ce02d131 100644 --- a/paddle/fluid/operators/softmax_cudnn_op.cu.cc +++ b/paddle/fluid/operators/softmax_cudnn_op.cu.cc @@ -12,60 +12,90 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/math/softmax.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/softmax_op.h" +#include "paddle/fluid/platform/cudnn_desc.h" +#include "paddle/fluid/platform/cudnn_helper.h" namespace paddle { namespace operators { +using ScopedTensorDescriptor = platform::ScopedTensorDescriptor; +using DataLayout = platform::DataLayout; using Tensor = framework::Tensor; +static inline int SizeOutAxis(const int axis, DDim dims) { + int size = 1; + for (int i = axis + 1; i < dims.size(); i++) { + size *= dims[i]; + } + return size; +} + template class SoftmaxCUDNNKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext& context) const override { - auto* X = context.Input("X"); - auto* Out = context.Output("Out"); - - // allocate memory on device. - Out->mutable_data(context.GetPlace()); - - auto dims = X->dims(); - auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1); - framework::LoDTensor flattened_x; - framework::LoDTensor flattened_out; - flattened_x.ShareDataWith(*X).Resize(flattened_dims); - flattened_out.ShareDataWith(*Out).Resize(flattened_dims); - - math::SoftmaxCUDNNFunctor()( - context.template device_context(), - &flattened_x, &flattened_out); + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* out = ctx.Output("Out"); + out->mutable_data(ctx.GetPlace()); + auto* out_data = out->data(); + + auto dims = x->dims(); + const int rank = dims.size(); + const int axis = CanonicalAxis(ctx.Attr("axis"), rank); + const int dim = dims[axis]; + const int N = SizeToAxis(axis, dims); + const int D = SizeOutAxis(axis, dims); + + ScopedTensorDescriptor desc; + std::vector tensor_dims = {N, dim, D, 1}; + DataLayout layout = DataLayout::kNCHW; + cudnnTensorDescriptor_t desc_ = desc.descriptor(layout, tensor_dims); + + auto& dev_ctx = ctx.template device_context(); + auto handle = dev_ctx.cudnn_handle(); + auto mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE + : CUDNN_SOFTMAX_MODE_CHANNEL; + + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSoftmaxForward( + handle, CUDNN_SOFTMAX_ACCURATE, mode, + platform::CudnnDataType::kOne(), desc_, x->data(), + platform::CudnnDataType::kZero(), desc_, out_data)); } }; template class SoftmaxGradCUDNNKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext& context) const override { - auto* Out = context.Input("Out"); - auto* dOut = context.Input(framework::GradVarName("Out")); - auto* dX = context.Output(framework::GradVarName("X")); - - // allocate memory on device. - dX->mutable_data(context.GetPlace()); - - auto dims = Out->dims(); - auto flattened_dims = framework::flatten_to_2d(dims, dims.size() - 1); - framework::LoDTensor flattened_out; - framework::LoDTensor flattened_d_out; - framework::LoDTensor flattened_d_x; - flattened_out.ShareDataWith(*Out).Resize(flattened_dims); - flattened_d_out.ShareDataWith(*dOut).Resize(flattened_dims); - flattened_d_x.ShareDataWith(*dX).Resize(flattened_dims); - - math::SoftmaxGradCUDNNFunctor()( - context.template device_context(), - &flattened_out, &flattened_d_out, &flattened_d_x); + void Compute(const framework::ExecutionContext& ctx) const override { + auto* out = ctx.Input("Out"); + auto* dout = ctx.Input(framework::GradVarName("Out")); + auto* dx = ctx.Output(framework::GradVarName("X")); + dx->mutable_data(ctx.GetPlace()); + auto* dx_data = dx->data(); + + auto dims = out->dims(); + const int rank = dims.size(); + const int axis = CanonicalAxis(ctx.Attr("axis"), rank); + const int dim = dims[axis]; + const int N = SizeToAxis(axis, dims); + const int D = SizeOutAxis(axis, dims); + + ScopedTensorDescriptor desc; + std::vector tensor_dims = {N, dim, D, 1}; + DataLayout layout = DataLayout::kNCHW; + cudnnTensorDescriptor_t desc_ = desc.descriptor(layout, tensor_dims); + + auto& dev_ctx = ctx.template device_context(); + auto handle = dev_ctx.cudnn_handle(); + auto mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE + : CUDNN_SOFTMAX_MODE_CHANNEL; + + PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSoftmaxBackward( + handle, CUDNN_SOFTMAX_ACCURATE, mode, + platform::CudnnDataType::kOne(), desc_, out->data(), desc_, + dout->data(), platform::CudnnDataType::kZero(), desc_, dx_data)); } }; diff --git a/paddle/fluid/operators/softmax_op.cc b/paddle/fluid/operators/softmax_op.cc index 2a6ca7975f0c591701400e71feab0be36300480b..cf46b4fc3bdad486f65afa5ac994d506c20344cb 100644 --- a/paddle/fluid/operators/softmax_op.cc +++ b/paddle/fluid/operators/softmax_op.cc @@ -53,13 +53,6 @@ class SoftmaxOp : public framework::OperatorWithKernel { "Attr(axis) value should be in range [-R, R-1], " "R is the rank of Input(X).")); - auto use_cudnn = ctx->Attrs().Get("use_cudnn"); - if (axis != rank_x - 1 && axis != -1) { - PADDLE_ENFORCE_EQ(use_cudnn, false, - platform::errors::InvalidArgument( - "CUDNN kernel only support axis as -1.")); - } - ctx->SetOutputDim("Out", ctx->GetInputDim("X")); ctx->ShareLoD("X", /*->*/ "Out"); } diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_scale_op.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_scale_op.py new file mode 100644 index 0000000000000000000000000000000000000000..851394710f34c739556fe11e042803101772ea56 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_scale_op.py @@ -0,0 +1,49 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from inference_pass_test import InferencePassTest +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid.core import AnalysisConfig + + +class TRTScaleTest(InferencePassTest): + def setUp(self): + with fluid.program_guard(self.main_program, self.startup_program): + data = fluid.data(name="data", shape=[-1, 512], dtype="float32") + scale_out = self.append_scale(data) + out = fluid.layers.batch_norm(scale_out, is_test=True) + + self.feeds = {"data": np.random.random([1, 512]).astype("float32"), } + self.enable_trt = True + self.trt_parameters = TRTScaleTest.TensorRTParam( + 1 << 30, 32, 1, AnalysisConfig.Precision.Float32, False, False) + self.fetch_list = [out] + + def append_scale(self, data): + return fluid.layers.scale( + x=data, scale=2.0, bias=-1.0, bias_after_scale=False) + + def test_check_output(self): + if core.is_compiled_with_cuda(): + use_gpu = True + self.check_output_with_option(use_gpu, flatten=True) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_subgraph_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_subgraph_pass.py index 17c92a24eb066af13d51da659dcd8c44478e8f56..c9095b07e5e0dcb0377221450b7f0f19da31fc39 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_subgraph_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_subgraph_pass.py @@ -654,6 +654,56 @@ class TensorRTSubgraphPassElementwiseMulTest( return fluid.layers.elementwise_mul(x=data1, y=data2) +class TensorRTSubgraphPassElementwiseSerializeTest( + TensorRTSubgraphPassElementwiseTest): + def setUp(self): + super(TensorRTSubgraphPassElementwiseSerializeTest, self).setUp() + self.trt_parameters = TensorRTSubgraphPassElementwiseTest.TensorRTParam( + 1 << 30, 32, 0, AnalysisConfig.Precision.Float32, True, False) + + def test_check_output(self): + if os.path.exists(self.path + "_opt_cache"): + shutil.rmtree(self.path + "_opt_cache") + super(TensorRTSubgraphPassElementwiseSerializeTest, + self).test_check_output() + + +class TensorRTSubgraphPassElementwiseBroadcastDynamicTest(InferencePassTest): + def setUp(self): + with fluid.program_guard(self.main_program, self.startup_program): + data1 = fluid.data( + name="data1", shape=[-1, 3, 64, 64], dtype="float32") + data2 = fluid.data(name="data2", shape=[64, 64], dtype="float32") + eltwise_out = self.append_eltwise(data1, data2) + out = fluid.layers.batch_norm(eltwise_out, is_test=True) + self.feeds = { + "data1": np.random.random([1, 3, 64, 64]).astype("float32"), + "data2": np.random.random([64, 64]).astype("float32"), + } + self.enable_trt = True + self.trt_parameters = TensorRTSubgraphPassElementwiseBroadcastDynamicTest.TensorRTParam( + 1 << 30, 32, 0, AnalysisConfig.Precision.Float32, True, False) + self.dynamic_shape_params = TensorRTSubgraphPassElementwiseBroadcastDynamicTest.DynamicShapeParam( + { + 'data1': [1, 3, 8, 64], + 'data2': [8, 64] + }, {'data1': [1, 3, 512, 64], + 'data2': + [512, 64]}, {'data1': [1, 3, 256, 64], + 'data2': [256, 64]}, False) + self.fetch_list = [out] + + def append_eltwise(self, data1, data2): + return fluid.layers.elementwise_add(x=data1, y=data2) + + def test_check_output(self): + if os.path.exists(self.path + "_opt_cache"): + shutil.rmtree(self.path + "_opt_cache") + if core.is_compiled_with_cuda(): + use_gpu = True + self.check_output_with_option(use_gpu) + + class TensorRTSubgraphPassShuffleChannelTest(InferencePassTest): def setUp(self): with fluid.program_guard(self.main_program, self.startup_program): diff --git a/python/paddle/fluid/tests/unittests/test_softmax_op.py b/python/paddle/fluid/tests/unittests/test_softmax_op.py index c393b55d7bd2cecbb89e271555c901e81ff7eadd..11f32e712236c339ba5cd5bbe463c3f27d74ab93 100644 --- a/python/paddle/fluid/tests/unittests/test_softmax_op.py +++ b/python/paddle/fluid/tests/unittests/test_softmax_op.py @@ -153,16 +153,103 @@ class TestSoftmaxCUDNNOp2(TestSoftmaxCUDNNOp): return [2, 3, 4, 5] +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestSoftmaxCUDNNOp3(TestSoftmaxCUDNNOp): + def get_x_shape(self): + return [2, 3, 4, 5] + + def get_axis(self): + return 0 + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestSoftmaxCUDNNOp4(TestSoftmaxCUDNNOp): + def get_x_shape(self): + return [2, 3, 4, 5] + + def get_axis(self): + return 1 + + @unittest.skipIf(not core.is_compiled_with_cuda(), "core is not compiled with CUDA") class TestSoftmaxCUDNNOp5(TestSoftmaxCUDNNOp): def get_x_shape(self): return [2, 3, 4, 5] + def get_axis(self): + return 2 + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestSoftmaxCUDNNOp6(TestSoftmaxCUDNNOp): + def get_x_shape(self): + return [2, 3, 4, 5] + def get_axis(self): return 3 +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestSoftmaxCUDNNOp7(TestSoftmaxCUDNNOp): + def get_x_shape(self): + return [2, 3, 4, 5, 6] + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestSoftmaxCUDNNOp8(TestSoftmaxCUDNNOp): + def get_x_shape(self): + return [2, 3, 4, 5, 6] + + def get_axis(self): + return 0 + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestSoftmaxCUDNNOp9(TestSoftmaxCUDNNOp): + def get_x_shape(self): + return [2, 3, 4, 5, 6] + + def get_axis(self): + return 1 + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestSoftmaxCUDNNOp10(TestSoftmaxCUDNNOp): + def get_x_shape(self): + return [2, 3, 4, 5, 6] + + def get_axis(self): + return 2 + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestSoftmaxCUDNNOp11(TestSoftmaxCUDNNOp): + def get_x_shape(self): + return [2, 3, 4, 5, 6] + + def get_axis(self): + return 3 + + +@unittest.skipIf(not core.is_compiled_with_cuda(), + "core is not compiled with CUDA") +class TestSoftmaxCUDNNOp12(TestSoftmaxCUDNNOp): + def get_x_shape(self): + return [2, 3, 4, 5, 6] + + def get_axis(self): + return 4 + + @unittest.skipIf(not core.is_compiled_with_cuda(), "core is not compiled with CUDA") class TestSoftmaxFP16Op(TestSoftmaxOp):