From ffc8defa62d66a476cec8512ede44ad9ec089c62 Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <39978853+zhoutianzi666@users.noreply.github.com> Date: Thu, 4 Aug 2022 16:55:22 +0800 Subject: [PATCH] [Paddle-TRT] add Rnn (#44678) * add rnn --- .../fluid/inference/api/analysis_predictor.cc | 2 + .../inference/tensorrt/convert/CMakeLists.txt | 2 + .../fill_constant_batch_size_like_op.cc | 86 +++++ .../inference/tensorrt/convert/rnn_op.cc | 320 ++++++++++++++++++ paddle/fluid/inference/tensorrt/op_teller.cc | 55 +++ .../ir/inference/test_trt_convert_rnn.py | 253 ++++++++++++++ 6 files changed, 718 insertions(+) create mode 100644 paddle/fluid/inference/tensorrt/convert/fill_constant_batch_size_like_op.cc create mode 100644 paddle/fluid/inference/tensorrt/convert/rnn_op.cc create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_rnn.py diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index bde92c13b4c..9c673dfc575 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -2096,6 +2096,8 @@ USE_TRT_CONVERTER(preln_residual_bias) USE_TRT_CONVERTER(c_allreduce_sum) USE_TRT_CONVERTER(roll) USE_TRT_CONVERTER(strided_slice) +USE_TRT_CONVERTER(rnn) +USE_TRT_CONVERTER(fill_constant_batch_size_like) USE_TRT_CONVERTER(transformer_input_convert) USE_TRT_CONVERTER(cast) USE_TRT_CONVERTER(recover_padding) diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index 4a13b6c00ac..4f563c2df8e 100644 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -69,6 +69,8 @@ list( top_k_op.cc squeeze2_op.cc unsqueeze2_op.cc + rnn_op.cc + fill_constant_batch_size_like_op.cc sum_op.cc shape_op.cc fill_constant_op.cc diff --git a/paddle/fluid/inference/tensorrt/convert/fill_constant_batch_size_like_op.cc b/paddle/fluid/inference/tensorrt/convert/fill_constant_batch_size_like_op.cc new file mode 100644 index 00000000000..5f00777a663 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/fill_constant_batch_size_like_op.cc @@ -0,0 +1,86 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +class FillConstantBatchSizeLikeOpConverter : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, + bool test_mode) override { +#if IS_TRT_VERSION_GE(7000) + VLOG(4) << "convert a fluid fill_constant_batch_size_like op to tensorrt " + "fill_constant_batch_size_like layer"; + + framework::OpDesc op_desc(op, nullptr); + auto* input = engine_->GetITensor(op_desc.Input("Input")[0]); + int dtype = PADDLE_GET_CONST(int, op_desc.GetAttr("dtype")); + // be float + PADDLE_ENFORCE_EQ(dtype, + 5, + platform::errors::InvalidArgument( + "fill_constant_batch_size_like's input data type " + "must be float in Paddle-TRT.")); + + int input_dim_idx = PADDLE_GET_CONST(int, op_desc.GetAttr("input_dim_idx")); + size_t output_dim_idx = + PADDLE_GET_CONST(int, op_desc.GetAttr("output_dim_idx")); + std::string str_value = + PADDLE_GET_CONST(std::string, op_desc.GetAttr("str_value")); + std::vector shape = + PADDLE_GET_CONST(std::vector, op_desc.GetAttr("shape")); + float value = std::stof(str_value); + + auto* input_shape_tensor = Shape(input); + auto* batch_tensor = GetEleTensorOfShape(input_shape_tensor, input_dim_idx); + std::string name = "_add_fill_constant_batch_size_like_op_"; + auto shape_attr_tensor = Add1DConstantLayer(shape, name + "shape_attr"); + std::vector gather_out_shape_indices; + for (size_t i = 0; i < shape.size(); i++) { + if (i == output_dim_idx) { + gather_out_shape_indices.push_back(shape.size()); + continue; + } + gather_out_shape_indices.push_back(i); + } + std::vector concat_inputs{shape_attr_tensor, + batch_tensor}; + auto out_shape_tensor = + Gather(Concat(concat_inputs), gather_out_shape_indices); + auto layer = TRT_ENGINE_ADD_LAYER( + engine_, Fill, nvinfer1::Dims{}, nvinfer1::FillOperation::kLINSPACE); + std::vector value_vec(1, value); + std::vector beta_vec(3, 0.); + layer->setAlpha(value); + layer->setBeta(0.f); + layer->setInput(0, *out_shape_tensor); + layer->setInput(1, *Add1DConstantLayer(value_vec, name + "alpha", true)); + layer->setInput(2, *Add1DConstantLayer(beta_vec, name + "beta", false)); + auto output_name = op_desc.Output("Out")[0]; + RreplenishLayerAndOutput( + layer, "fill_constant_batch_size_like", {output_name}, test_mode); +#endif + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(fill_constant_batch_size_like, + FillConstantBatchSizeLikeOpConverter); diff --git a/paddle/fluid/inference/tensorrt/convert/rnn_op.cc b/paddle/fluid/inference/tensorrt/convert/rnn_op.cc new file mode 100644 index 00000000000..945495c0d16 --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/rnn_op.cc @@ -0,0 +1,320 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" + +namespace paddle { +namespace inference { +namespace tensorrt { + +class RnnNativeOpConverter : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, + bool test_mode) override { +#if IS_TRT_VERSION_GE(7000) + VLOG(4) << "convert a fluid rnn op to tensorrt rnn layer"; + + framework::OpDesc op_desc(op, nullptr); + // [seq_len, batch ,in_size], + // [K * num_layers, batch ,in_size], [K * num_layers, batch ,in_size] + // K is defined below + auto* input = engine_->GetITensor(op_desc.Input("Input")[0]); + auto* prev_c = engine_->GetITensor(op_desc.Input("PreState")[0]); + auto* prev_h = engine_->GetITensor(op_desc.Input("PreState")[1]); + + PADDLE_ENFORCE_EQ(input->getDimensions().nbDims, + 3, + platform::errors::InvalidArgument( + "RNN(LSTM)'s input must be 3 dimensions, i.e. " + "[seq_len, batch, input_size]," + "but now is %d dimensions.", + input->getDimensions().nbDims)); + + PADDLE_ENFORCE_EQ(prev_h->getDimensions().nbDims, + 3, + platform::errors::InvalidArgument( + "RNN(LSTM)'s PreState(Hidden) must be 3 dimensions, " + "i.e. [num_layers, batch, hidden_size]," + "but now is %d dimensions.", + prev_h->getDimensions().nbDims)); + + PADDLE_ENFORCE_EQ(prev_c->getDimensions().nbDims, + 3, + platform::errors::InvalidArgument( + "RNN(LSTM)'s PreState(Cell) must be 3 dimensions, " + "i.e. [num_layers, batch, hidden_size]," + "but now is %d dimensions.", + prev_c->getDimensions().nbDims)); + + int num_layers = PADDLE_GET_CONST(int, op_desc.GetAttr("num_layers")); + int hidden_size = PADDLE_GET_CONST(int, op_desc.GetAttr("hidden_size")); + int input_size = PADDLE_GET_CONST(int, op_desc.GetAttr("input_size")); + bool is_bidirec = PADDLE_GET_CONST(bool, op_desc.GetAttr("is_bidirec")); + int K = is_bidirec ? 2 : 1; + + // extract weights + // if is_bidirec, make forward and backward weight/bias concated + std::vector weight_bias_vec; + for (int layer_id = 0; layer_id < num_layers; layer_id++) { + if (is_bidirec) { + auto extract_and_combine_weight = [&](int start) { + // k and k + 2 is combined ! + // k + 1 and k + 3 is combined ! + for (int k = 0; k < K; k++) { + std::string var0_name = op_desc.Input("WeightList")[k + start]; + std::string var1_name = op_desc.Input("WeightList")[k + 2 + start]; + auto* var0_v = scope.FindVar(var0_name); + auto* var1_v = scope.FindVar(var1_name); + auto* var0_t = var0_v->GetMutable(); + auto* var1_t = var1_v->GetMutable(); + const float* data0_ptr = reinterpret_cast( + engine_->GetTrtWeight(var0_name, *var0_t).get().values); + const float* data1_ptr = reinterpret_cast( + engine_->GetTrtWeight(var1_name, *var1_t).get().values); + float* data_ptr = new float[K * var0_t->numel()]; + // remember free + memcpy(data_ptr, data0_ptr, sizeof(float) * var0_t->numel()); + memcpy(data_ptr + var0_t->numel(), + data1_ptr, + sizeof(float) * var1_t->numel()); + weight_bias_vec.push_back(data_ptr); + } + }; + extract_and_combine_weight(4 * layer_id); + extract_and_combine_weight(4 * layer_id + 4 * num_layers); + } else { + auto extract_weight = [&](int start) { + for (int k = 0; k < 2 * K; k++) { + std::string var_name = op_desc.Input("WeightList")[k + start]; + auto* var_v = scope.FindVar(var_name); + auto* var_t = var_v->GetMutable(); + const float* data_ptr = reinterpret_cast( + engine_->GetTrtWeight(var_name, *var_t).get().values); + weight_bias_vec.push_back(data_ptr); + } + }; + extract_weight(2 * layer_id); // filter + extract_weight(2 * num_layers + 2 * layer_id); // bias + } + } + // [seq_len, batch ,in_size] + + nvinfer1::ITensor* this_input = + TRT_ENGINE_ADD_LAYER(engine_, Identity, *input)->getOutput(0); + + nvinfer1::ILayer* finally_layer = nullptr; + for (int layer_id = 0; layer_id < num_layers; layer_id++) { + auto* loop = TRT_ENGINE_ADD_LAYER(engine_, Loop); + auto* input_shape_tensor = Shape(this_input); + auto* seq_len_scalar = GetEleTensorOfShape(input_shape_tensor, 0, true); + auto* seq_len_tensor = GetEleTensorOfShape(input_shape_tensor, 0); + auto* batch_tensor = GetEleTensorOfShape(input_shape_tensor, 1); + auto* K_tensor = Add1DConstantLayer(K); + auto* hidden_size_tensor = Add1DConstantLayer(hidden_size); + + if (layer_id > 0) input_size = K * hidden_size; + auto* input_size_tensor = Add1DConstantLayer(input_size); + + loop->addTripLimit(*seq_len_scalar, nvinfer1::TripLimit::kCOUNT); + + nvinfer1::ITensor* iter_input_tensor; + auto* iter_input_forward_tensor = + loop->addIterator(*this_input)->getOutput(0); // [batch, input_size] + + // this function shuffle tensor -> 4 dims + auto reshape2four = [&](nvinfer1::ITensor** tensor) { +#if TRT_VERSION == 7234 + auto* tmp_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, **tensor); + std::vector concat_inputs{ + Add1DConstantLayer(1), Add1DConstantLayer(1), Shape(*tensor)}; + tmp_layer->setInput(1, *Concat(concat_inputs)); + *tensor = tmp_layer->getOutput(0); +#endif + }; + + reshape2four(&iter_input_forward_tensor); + + if (is_bidirec) { + auto* iter_input_reverse_tensor = + loop->addIterator(*this_input, 0, true) + ->getOutput(0); // [batch, input_size] + + reshape2four(&iter_input_reverse_tensor); + + std::vector concat_inputs{ + iter_input_forward_tensor, iter_input_reverse_tensor}; + iter_input_tensor = Concat(concat_inputs); + } else { + iter_input_tensor = iter_input_forward_tensor; + } + + auto* tmp_layer = + TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *iter_input_tensor); + + tmp_layer->setInput(1, + *Concat(std::vector{ + K_tensor, batch_tensor, input_size_tensor})); + + iter_input_tensor = tmp_layer->getOutput(0); + // [K, batch, input_size] + + std::vector tmp_vec(K); + std::iota(tmp_vec.begin(), tmp_vec.end(), 2 * layer_id); + auto* first_prev_h = Gather(prev_h, tmp_vec); + auto* first_prev_c = Gather(prev_c, tmp_vec); + + nvinfer1::IRecurrenceLayer* Hlayer = loop->addRecurrence(*first_prev_h); + nvinfer1::IRecurrenceLayer* Clayer = loop->addRecurrence(*first_prev_c); + + // k is weight + // k + 2 is bias + auto run_matmul_bias = [&](int k, bool is_input) -> nvinfer1::ITensor* { + int h = 4 * hidden_size; + int w = is_input ? input_size : hidden_size; + if (is_input && k > 0) w = K * hidden_size; + + auto weight_shape = nvinfer1::Dims3{K, h, w}; + auto* weight_tensor = + AddConstantLayer(weight_bias_vec[k], weight_shape, " "); + auto bias_shape = nvinfer1::Dims3{K, 1, h}; + auto* bias_tensor = + AddConstantLayer(weight_bias_vec[k + 2], bias_shape, " "); + + nvinfer1::ITensor* iter_tensor = + k % 2 ? Hlayer->getOutput(0) : iter_input_tensor; + + auto* iter_w_tensor = + TRT_ENGINE_ADD_LAYER(engine_, + MatrixMultiply, + *iter_tensor, + nvinfer1::MatrixOperation::kNONE, + *weight_tensor, + nvinfer1::MatrixOperation::kTRANSPOSE) + ->getOutput(0); + + auto* iter_w_b_tensor = Sum(iter_w_tensor, bias_tensor); + return iter_w_b_tensor; + }; + + nvinfer1::ITensor* iter_input_w_b_tensor = + run_matmul_bias(layer_id * 4, true); + nvinfer1::ITensor* iter_hidden_w_b_tensor = + run_matmul_bias(layer_id * 4 + 1, false); + auto* iter_input_hidden_add_tensor = + Sum(iter_input_w_b_tensor, iter_hidden_w_b_tensor); + + nvinfer1::Dims start_dims = nvinfer1::Dims3{0, 0, 0}; + nvinfer1::Dims size_dims = nvinfer1::Dims3{0, 0, 0}; + auto* size_dims_tensor = Concat(std::vector{ + K_tensor, batch_tensor, hidden_size_tensor}); + nvinfer1::Dims step_dims = nvinfer1::Dims3{1, 1, 1}; + + std::vector lstm_act{ + nvinfer1::ActivationType::kSIGMOID, nvinfer1::ActivationType::kTANH}; + + auto split_gate = [&](int i, int act_i = 0) -> nvinfer1::ITensor* { + start_dims.d[2] = i * hidden_size; + auto* gate_layer = TRT_ENGINE_ADD_LAYER(engine_, + Slice, + *iter_input_hidden_add_tensor, + start_dims, + size_dims, + step_dims); + gate_layer->setInput(2, *size_dims_tensor); + auto* gate = gate_layer->getOutput(0); + gate = Act(gate, lstm_act[act_i]); + return gate; + }; + + auto* i_gate = split_gate(0); + auto* f_gate = split_gate(1); + auto* c_gate = split_gate(2, 1); + auto* o_gate = split_gate(3); + + // C_t = i_gate * c_gate + f_gate * C_{t-1} + auto* ic_gate = Prod(i_gate, c_gate); + auto* fCt1_gate = Prod(f_gate, Clayer->getOutput(0)); + auto* Ct = Sum(ic_gate, fCt1_gate); + Clayer->setInput(1, *Ct); + // H_t = tanh(C_t) * o_gate + auto* tanh_Ct = Act(Ct, lstm_act[1]); + auto* Ht = Prod(o_gate, tanh_Ct); + Hlayer->setInput(1, *Ht); + + // Ht: [K, batch, hidden_size] + nvinfer1::ILayer* layer = nullptr; + nvinfer1::ITensor* tensor = nullptr; + if (is_bidirec) { + auto* slice_forward_layer = + TRT_ENGINE_ADD_LAYER(engine_, + Slice, + *Ht, + nvinfer1::Dims3{0, 0, 0}, + nvinfer1::Dims3{0, 0, 0}, + nvinfer1::Dims3{1, 1, 1}); + auto* slice_reverse_layer = + TRT_ENGINE_ADD_LAYER(engine_, + Slice, + *Ht, + nvinfer1::Dims3{1, 0, 0}, + nvinfer1::Dims3{0, 0, 0}, + nvinfer1::Dims3{1, 1, 1}); + auto* one_tensor = Add1DConstantLayer(1); + auto* size_dims_tensor = Concat(std::vector{ + one_tensor, batch_tensor, hidden_size_tensor}); + slice_forward_layer->setInput(2, *size_dims_tensor); + slice_reverse_layer->setInput(2, *size_dims_tensor); + + auto* layer0 = loop->addLoopOutput(*slice_forward_layer->getOutput(0), + nvinfer1::LoopOutput::kCONCATENATE); + auto* layer1 = loop->addLoopOutput(*slice_reverse_layer->getOutput(0), + nvinfer1::LoopOutput::kREVERSE); + layer0->setInput(1, *seq_len_scalar); + layer1->setInput(1, *seq_len_scalar); + + std::vector concat_inputs{layer0->getOutput(0), + layer1->getOutput(0)}; + tensor = Concat(concat_inputs, 3); + } else { + layer = loop->addLoopOutput(*Ht, nvinfer1::LoopOutput::kCONCATENATE); + layer->setInput(1, *seq_len_scalar); + tensor = layer->getOutput(0); + } + finally_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *tensor); + auto* hidden_size_k_tensor = Add1DConstantLayer(hidden_size * K); + nvinfer1::ITensor* final_dims_tensor = + Concat(std::vector{ + seq_len_tensor, batch_tensor, hidden_size_k_tensor}); + finally_layer->setInput(1, *final_dims_tensor); + // update input + this_input = finally_layer->getOutput(0); + } + + auto output_name = op_desc.Output("Out")[0]; + RreplenishLayerAndOutput(finally_layer, "rnn", {output_name}, test_mode); + // free + if (is_bidirec) { + for (size_t i = 0; i < weight_bias_vec.size(); i++) + delete[] weight_bias_vec[i]; + } +#endif + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(rnn, RnnNativeOpConverter); diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index 6bbdbe90cc7..6ed6ba57075 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -40,6 +40,10 @@ struct SimpleOpTypeSetTeller : public Teller { #if IS_TRT_VERSION_GE(7000) teller_set.insert("tile"); teller_set.insert("flatten_contiguous_range"); + teller_set.insert("rnn"); + int8_teller_set.insert("rnn"); + teller_set.insert("fill_constant_batch_size_like"); + int8_teller_set.insert("fill_constant_batch_size_like"); #endif #if CUDA_VERSION >= 10020 teller_set.insert("reshape"); @@ -1249,6 +1253,57 @@ bool OpTeller::Tell(const framework::ir::Node* node, } } + if (op_type == "rnn") { + if (!with_dynamic_shape) { + return false; + } + if (desc.HasAttr("mode")) { + std::string mode = PADDLE_GET_CONST(std::string, desc.GetAttr("mode")); + if (mode != "LSTM") return false; + } + if (desc.HasAttr("dropout_prob")) { + float dropout_prob = + PADDLE_GET_CONST(float, desc.GetAttr("dropout_prob")); + if (dropout_prob > 1e-5) return false; + } + // not support following four inputs for rnn in paddle-trt + auto rnn_inputs = desc.Inputs(); + if (rnn_inputs.find("SequenceLength") != rnn_inputs.end()) { + if (desc.Input("SequenceLength").size()) { + return false; + } + } + } + + if (op_type == "fill_constant_batch_size_like") { + if (!with_dynamic_shape) { + return false; + } + if (!desc.HasAttr("input_dim_idx")) { + return false; + } + if (!desc.HasAttr("output_dim_idx")) { + return false; + } + if (!desc.HasAttr("shape")) { + return false; + } + auto* block = desc.Block(); + if (block == nullptr) { + VLOG(3) << "The block desc is nullptr, we can't continue to analyze. " + "Developers need to check whether block_desc is passed in " + "the pass."; + return false; + } + auto x_var_name = desc.Input("Input")[0]; + auto* x_var_desc = block->FindVar(x_var_name); + auto dtype = x_var_desc->GetDataType(); + // At present, only support float32 into trt. + if (dtype != 5) { + return false; + } + } + if (op_type == "slice") { if (desc.HasAttr("decrease_axis")) { std::vector decrease_axis = diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_rnn.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_rnn.py new file mode 100644 index 00000000000..2a3c25bab11 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_rnn.py @@ -0,0 +1,253 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from trt_layer_auto_scan_test import TrtLayerAutoScanTest, SkipReasons +from program_config import TensorConfig, ProgramConfig +import numpy as np +import paddle.inference as paddle_infer +from functools import partial +from typing import Optional, List, Callable, Dict, Any, Set +import unittest +import os + + +class TrtConvertSliceTest(TrtLayerAutoScanTest): + + def is_program_valid(self, program_config: ProgramConfig) -> bool: + return True + + def sample_program_configs(self): + self.trt_param.workspace_size = 1073741824 + for hidden_size in [30]: + for input_size in [30]: + for batch in [2]: + for seq_len in [5]: + for num_layers in [1, 2]: + for is_bidirec in [True, False]: + dics = [] + dics.append({ + "hidden_size": hidden_size, + "input_size": input_size, + "num_layers": num_layers, + "mode": "LSTM", + "is_bidirec": is_bidirec, + "is_test": True, + "dropout_prob": 0.0, + # for my convience + "batch": batch, + "seq_len": seq_len, + }) + + K = 1 + if (dics[0]["is_bidirec"]): + K = 2 + + def generate_input1(): + return np.random.random([ + batch, seq_len, input_size + ]).astype(np.float32) * 2 - 1 + + # initial input -> hidden + def generate_w0(): + return np.random.random([ + 4 * hidden_size, input_size + ]).astype(np.float32) * 2 - 1 + + # prev layer's output -> hidden + def generate_w1(): + return np.random.random([ + 4 * hidden_size, K * hidden_size + ]).astype(np.float32) * 2 - 1 + + # + def generate_w2(): + return np.random.random([ + 4 * hidden_size, hidden_size + ]).astype(np.float32) * 2 - 1 + + def generate_b(): + return np.random.random([ + 4 * hidden_size + ]).astype(np.float32) * 2 - 1 + + dics.append({ + "dtype": + 5, + "input_dim_idx": + 0, + "str_value": + "0.0", + "shape": [K * num_layers, -1, hidden_size], + "output_dim_idx": + 1, + }) + dics.append({"axis": [1, 0, 2]}) + # set weights + WeightList = [ + "weight" + str(i) + for i in range(4 * K * + dics[0]["num_layers"]) + ] + weights = {} + for i in range((int)(len(WeightList) / 2)): + # mean this weight : input->hidden + # input has 2 case: initial input input_size, K * hidden form the prev layer. + if (i % 2 == 0): + if (i <= K): + weights[ + WeightList[i]] = TensorConfig( + data_gen=partial( + generate_w0)) + else: + weights[ + WeightList[i]] = TensorConfig( + data_gen=partial( + generate_w1)) + # mean this weight : hidden->hidden + if (i % 2 == 1): + weights[WeightList[i]] = TensorConfig( + data_gen=partial(generate_w2)) + for i in range((int)(len(WeightList) / 2), + len(WeightList)): + weights[WeightList[i]] = TensorConfig( + data_gen=partial(generate_b)) + ops_config = [ + { + "op_type": + "fill_constant_batch_size_like", + "op_inputs": { + "Input": ["input_data"] + }, + "op_outputs": { + "Out": ["prestate1"] + }, + "op_attrs": dics[1] + }, + { + "op_type": + "fill_constant_batch_size_like", + "op_inputs": { + "Input": ["input_data"] + }, + "op_outputs": { + "Out": ["prestate2"] + }, + "op_attrs": dics[1] + }, + { + "op_type": "transpose2", + "op_inputs": { + "X": ["input_data"] + }, + "op_outputs": { + "Out": ["rnn_input_data"] + }, + "op_attrs": dics[2] + }, + { + "op_type": "rnn", + "op_inputs": { + "Input": ["rnn_input_data"], + # prev_c, prev_h + "PreState": + ["prestate1", "prestate2"], + "WeightList": WeightList, + }, + "op_outputs": { + "Out": ["rnn_output_data"], + "State": [ + "state_output_data0", + "state_output_data1" + ], + "Reserve": ["reserve_data"], + "DropoutState": + ["DropoutState_data"] + }, + "op_attrs": dics[0] + } + ] + ops = self.generate_op_config(ops_config) + + program_config = ProgramConfig( + ops=ops, + weights=weights, + inputs={ + "input_data": + TensorConfig( + data_gen=partial(generate_input1)) + }, + outputs=["rnn_output_data"]) + + yield program_config + + def sample_predictor_configs( + self, program_config) -> (paddle_infer.Config, List[int], float): + attrs = [ + program_config.ops[i].attrs for i in range(len(program_config.ops)) + ] + num_layers = attrs[3]["num_layers"] + hidden_size = attrs[3]["hidden_size"] + batch = attrs[3]["batch"] + input_size = attrs[3]["input_size"] + seq_len = attrs[3]["seq_len"] + + K = 1 + if attrs[3]["is_bidirec"]: + K = 2 + + def generate_dynamic_shape(attrs): + self.dynamic_shape.min_input_shape = { + "input_data": [batch - 1, seq_len, input_size], + } + self.dynamic_shape.max_input_shape = { + "input_data": [batch + 1, seq_len, input_size], + } + self.dynamic_shape.opt_input_shape = { + "input_data": [batch, seq_len, input_size], + } + + def clear_dynamic_shape(): + self.dynamic_shape.min_input_shape = {} + self.dynamic_shape.max_input_shape = {} + self.dynamic_shape.opt_input_shape = {} + + def generate_trt_nodes_num(attrs, dynamic_shape): + return 1, 2 + + attrs = [ + program_config.ops[i].attrs for i in range(len(program_config.ops)) + ] + + # The output has diff between gpu and trt in PR-CI-Windows-Inference + tol_fp32 = 1e-5 + tol_half = 1e-2 + if (os.name == 'nt'): + tol_fp32 = 1e-2 + tol_half = 1e-1 + + # for dynamic_shape + generate_dynamic_shape(attrs) + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, True), tol_fp32 + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, True), tol_half + + def test(self): + self.run_test() + + +if __name__ == "__main__": + unittest.main() -- GitLab