diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index b63db0bab483ab11de137c11a6fa133455dc3937..1a1619fa969347abcc87e5ca734546c00e46c69d 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -2754,6 +2754,7 @@ USE_TRT_CONVERTER(dropout);
 USE_TRT_CONVERTER(pad);
 #if IS_TRT_VERSION_GE(8200)
 USE_TRT_CONVERTER(pad3d);
+USE_TRT_CONVERTER(einsum);
 #endif
 USE_TRT_CONVERTER(hard_sigmoid);
 USE_TRT_CONVERTER(hard_swish);
diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
index 90b4cec1f9ac8118cc4ef2e0577715e46d1ba087..1064362df387860804d31491b1db664827ae3c28 100755
--- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
+++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt
@@ -105,7 +105,8 @@ list(
   preln_groupnorm_act_op.cc
   expand_v2_op.cc
   cumsum_op.cc
-  temporal_shift_op.cc)
+  temporal_shift_op.cc
+  einsum_op.cc)
 if(${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 7)
   list(APPEND CONVERT_FILES
        emb_eltwise_layernorm.cc
diff --git a/paddle/fluid/inference/tensorrt/convert/einsum_op.cc b/paddle/fluid/inference/tensorrt/convert/einsum_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..e43615da01c09c90d7e488a60771c5bb632c2515
--- /dev/null
+++ b/paddle/fluid/inference/tensorrt/convert/einsum_op.cc
@@ -0,0 +1,60 @@
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
+
+namespace paddle {
+namespace inference {
+namespace tensorrt {
+
+/*
+ * Einsum Op
+ *
+ * Converts a Paddle `einsum` op into a TensorRT Einsum layer: every tensor
+ * named in the "Operands" input is fetched from the engine and passed to the
+ * layer together with the unmodified "equation" attribute. The layer is only
+ * added when compiled against TensorRT >= 8.2.0; older versions just log.
+ */
+class EinsumOpConverter : public OpConverter {
+ public:
+  void operator()(const framework::proto::OpDesc& op,
+                  const framework::Scope& scope,
+                  bool test_mode) override {
+#if IS_TRT_VERSION_GE(8200)
+    VLOG(3) << "convert a einsum op to tensorrt layer";
+    framework::OpDesc op_desc(op, nullptr);
+    auto operand_inputs = op_desc.Input("Operands");
+    auto equation = PADDLE_GET_CONST(std::string, op_desc.GetAttr("equation"));
+
+    // Look up the engine tensor bound to each operand name, in order.
+    std::vector<nvinfer1::ITensor*> input_tensors;
+    input_tensors.reserve(operand_inputs.size());
+    for (const auto& input_name : operand_inputs) {
+      input_tensors.push_back(engine_->GetITensor(input_name));
+    }
+
+    const int32_t input_num = static_cast<int32_t>(operand_inputs.size());
+    auto layer = TRT_ENGINE_ADD_LAYER(
+        engine_, Einsum, input_tensors.data(), input_num, equation.c_str());
+
+    auto output_name = op_desc.Output("Out")[0];
+    RreplenishLayerAndOutput(layer, "einsum", {output_name}, test_mode);
+#else
+    VLOG(3) << "Einsum is not supported when TensorRT < 8.2.0";
+#endif
+  }
+};
+
+}  // namespace tensorrt
+}  // namespace inference
+}  // namespace paddle
+
+REGISTER_TRT_OP_CONVERTER(einsum, EinsumOpConverter);
diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc
index 6dbb05bbff867248471c01706aa0b830b05e92df..ff6b49e79c9c18e427ee38504978c714ca874573 100644
--- a/paddle/fluid/inference/tensorrt/op_teller.cc
+++ b/paddle/fluid/inference/tensorrt/op_teller.cc
@@ -79,6 +79,8 @@ struct SimpleOpTypeSetTeller : public Teller {
     teller_set.insert("set_value");
     teller_set.insert("index_select");
     int8_teller_set.insert("index_select");
+    int8_teller_set.insert("einsum");
+    teller_set.insert("einsum");
 #endif
   }
 
@@ -2700,6 +2702,42 @@ struct SimpleOpTypeSetTeller : public Teller {
       }
     }
 
+    if (op_type == "einsum") {
+#if !IS_TRT_VERSION_GE(8200)
+      VLOG(3) << "einsum is not supported when TensorRT < 8.2";
+      return false;
+#else
+      // Only dynamic-shape mode is accepted for einsum.
+      if (!with_dynamic_shape) {
+        VLOG(3) << "the einsum does not support "
+                   "static shape yet";
+        return false;
+      }
+      // Reject ops with more than two operands.
+      auto operand_inputs = desc.Input("Operands");
+      if (operand_inputs.size() > 2) {
+        VLOG(3) << "TensorRT currently supports up to 2 input tensors "
+                << "to einsum but operation had " << operand_inputs.size()
+                << " input tensors !";
+        return false;
+      }
+
+      auto* block = desc.Block();
+      if (block == nullptr) {
+        VLOG(3) << "The block desc is nullptr, we can't continue to analyze. "
+                   "Developers need to check whether block_desc is passed in "
+                   "the pass.";
+        return false;
+      }
+      // Equations using the "..." ellipsis are not offloaded to TensorRT.
+      auto equation = PADDLE_GET_CONST(std::string, desc.GetAttr("equation"));
+      if (equation.find("...") != std::string::npos) {
+        VLOG(3) << "TensorRT currently does not support ellipses !";
+        return false;
+      }
+#endif
+    }
+
     if (use_no_calib_int8) {
       return int8_teller_set.count(op_type);
     } else {
diff --git a/test/ir/inference/test_trt_convert_einsum.py b/test/ir/inference/test_trt_convert_einsum.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f1fb5ebdd4bd987c6808651312223fa37a43aba
--- /dev/null
+++ b/test/ir/inference/test_trt_convert_einsum.py
@@ -0,0 +1,483 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from functools import partial
+from typing import List
+
+import numpy as np
+from program_config import ProgramConfig, TensorConfig
+from trt_layer_auto_scan_test import TrtLayerAutoScanTest
+
+import paddle.inference as paddle_infer
+
+
+class TrtConvertEinsumTest_SingleOperand(TrtLayerAutoScanTest):
+    def is_program_valid(self, program_config: ProgramConfig) -> bool:
+        ver = paddle_infer.get_trt_compile_version()
+        if ver[0] * 1000 + ver[1] * 100 + ver[2] * 10 < 8200:
+            return False
+        return True
+
+    def sample_program_configs(self):
+        self.trt_param.workspace_size = 1073741824  # 2**30
+
+        def generate_input1(dims, batch):
+            if dims == 1:
+                return np.ones(shape=[batch]).astype(np.float32)
+            elif dims == 2:
+                return np.ones(shape=[batch, 3]).astype(np.float32)
+            elif dims == 3:
+                return np.ones((batch, 2, 3)).astype(np.float32)
+
+        def generate_equation1(dims):
+            if dims == 1:
+                return ["i->"]
+            elif dims == 2:
+                # not exercised: "ij->"
+                return ["ij->ji", "ij->i", "ij->j"]
+            elif dims == 3:
+                # not exercised: "ijk->", "ijk->j", "ijk->k"
+                # error: The current implementation of Einsum doesn't support mask dimensions on multiple contracting/free dimensions
+                return [
+                    "ijk->ikj",
+                    "ijk->i",
+                    "ijk->ij",
+                    "ijk->ik",
+                    "ijk->ijk",
+                    "ijk->jk",
+                ]
+
+        # Single operand: transpose and sum (reduction) equations
+        for dims in [1, 2, 3]:
+            for batch in [2]:
+                equation_list = generate_equation1(dims)
+                for equation in equation_list:
+                    self.equation = equation
+                    self.dims = dims
+                    dics = [
+                        {
+                            "equation": equation,
+                        }
+                    ]
+                    ops_config = [
+                        {
+                            "op_type": "einsum",
+                            "op_inputs": {"Operands": ["operands_data0"]},
+                            "op_outputs": {"Out": ["einsum_output_data"]},
+                            "op_attrs": dics[0],
+                        }
+                    ]
+                    ops = self.generate_op_config(ops_config)
+
+                    program_config = ProgramConfig(
+                        ops=ops,
+                        weights={},
+                        inputs={
+                            "operands_data0": TensorConfig(
+                                data_gen=partial(generate_input1, dims, batch)
+                            )
+                        },
+                        outputs=["einsum_output_data"],
+                    )
+
+                    yield program_config
+
+    def sample_predictor_configs(
+        self, program_config
+    ) -> (paddle_infer.Config, List[int], float):
+        def generate_dynamic_shape(attrs):
+            if self.dims == 1:
+                self.dynamic_shape.min_input_shape = {
+                    "operands_data0": [1],
+                }
+                self.dynamic_shape.max_input_shape = {
+                    "operands_data0": [3],
+                }
+                self.dynamic_shape.opt_input_shape = {
+                    "operands_data0": [2],
+                }
+            elif self.dims == 2:
+                self.dynamic_shape.min_input_shape = {
+                    "operands_data0": [1, 3],
+                }
+                self.dynamic_shape.max_input_shape = {
+                    "operands_data0": [4, 3],
+                }
+                self.dynamic_shape.opt_input_shape = {
+                    "operands_data0": [2, 3],
+                }
+            elif self.dims == 3:
+                self.dynamic_shape.min_input_shape = {
+                    "operands_data0": [1, 2, 3],
+                }
+                self.dynamic_shape.max_input_shape = {
+                    "operands_data0": [4, 2, 3],
+                }
+                self.dynamic_shape.opt_input_shape = {
+                    "operands_data0": [2, 2, 3],
+                }
+
+        def clear_dynamic_shape():
+            self.dynamic_shape.min_input_shape = {}
+            self.dynamic_shape.max_input_shape = {}
+            self.dynamic_shape.opt_input_shape = {}
+
+        def generate_trt_nodes_num(attrs, dynamic_shape):
+            if (not dynamic_shape) or ("..." in self.equation):
+                return 0, 3
+            return 1, 2
+
+        attrs = [
+            program_config.ops[i].attrs for i in range(len(program_config.ops))
+        ]
+
+        # for static_shape (generate_trt_nodes_num returns (0, 3) here)
+        clear_dynamic_shape()
+        self.trt_param.precision = paddle_infer.PrecisionType.Float32
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, False
+        ), 1e-5
+        self.trt_param.precision = paddle_infer.PrecisionType.Half
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, False
+        ), 1e-5
+
+        # for dynamic_shape (min/max/opt shapes are chosen from self.dims)
+        generate_dynamic_shape(attrs)
+        self.trt_param.precision = paddle_infer.PrecisionType.Float32
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, True
+        ), 1e-5
+        self.trt_param.precision = paddle_infer.PrecisionType.Half
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, True
+        ), 1e-5
+
+    def test(self):
+        self.run_test()
+
+
+class TrtConvertEinsumTest_DoubuleOperand_Vector_Matrix(TrtLayerAutoScanTest):
+    def is_program_valid(self, program_config: ProgramConfig) -> bool:
+        ver = paddle_infer.get_trt_compile_version()
+        if ver[0] * 1000 + ver[1] * 100 + ver[2] * 10 < 8200:
+            return False
+        return True
+
+    def sample_program_configs(self):
+        self.trt_param.workspace_size = 1073741824  # 2**30
+
+        def generate_input_matrix(dims, batch):
+            if dims == 1:
+                return np.ones(shape=[batch]).astype(np.float32)
+            elif dims == 2:
+                return np.ones(shape=[batch, 3]).astype(np.float32)
+            elif dims == 3:
+                return np.ones((batch, 2, 3)).astype(np.float32)
+
+        """
+        genertate_vector
+        """
+
+        def generate_input_vector(vec_shape):
+            return np.ones(vec_shape).astype(np.float32)
+
+        def generate_equation_matrix_vector(dims, vec_shape):
+            if dims == 1:
+                return ["i,i->", "i,i->i", "i,j->ij"]
+            elif dims == 2 and vec_shape == [3]:
+                return ["ij,j->i", "ij,j->j", "ij,j->ij", "ij,j", "ij,j->"]
+            elif dims == 3 and vec_shape == [3]:
+                return [
+                    "ijk,k->i",
+                    "ijk,k->j",
+                    "ijk,k->k",
+                    "ijk,k->ij",
+                    "ijk,k->ik",
+                    "ijk,k->jk",
+                    "ijk,k->ijk",
+                    "ijk,k",
+                    "ijk,k->",
+                ]
+
+        # Double operands, vector cases
+        for dims in [1]:  # NOTE(review): dims 2 and 3 are never generated; confirm
+            self.dims = dims
+            for vec_shape in [[2], [3]]:
+                for batch in [2]:
+                    equation_list = generate_equation_matrix_vector(
+                        dims, vec_shape
+                    )
+                    for equation in equation_list:
+                        if (
+                            dims == 1
+                            and vec_shape != [2]
+                            and equation != "i,j->ij"
+                        ) or ((dims == 2 or dims == 3) and vec_shape != [3]):
+                            continue
+                        self.equation = equation
+                        self.dims = dims
+                        dics = [{"equation": equation}, {}]
+                        ops_config = [
+                            {
+                                "op_type": "einsum",
+                                "op_inputs": {
+                                    "Operands": [
+                                        "operands_data0",
+                                        "operands_data1",
+                                    ]
+                                },
+                                "op_outputs": {"Out": ["einsum_output_data"]},
+                                "op_attrs": dics[0],
+                            }
+                        ]
+                        ops = self.generate_op_config(ops_config)
+
+                        program_config = ProgramConfig(
+                            ops=ops,
+                            weights={},
+                            inputs={
+                                "operands_data0": TensorConfig(
+                                    data_gen=partial(
+                                        generate_input_matrix, dims, batch
+                                    )
+                                ),
+                                "operands_data1": TensorConfig(
+                                    data_gen=partial(
+                                        generate_input_vector, vec_shape
+                                    )
+                                ),
+                            },
+                            outputs=["einsum_output_data"],
+                        )
+
+                        yield program_config
+
+    def sample_predictor_configs(
+        self, program_config
+    ) -> (paddle_infer.Config, List[int], float):
+        def generate_dynamic_shape(attrs):
+            if self.dims == 1:
+                self.dynamic_shape.min_input_shape = {
+                    "operands_data0": [1],
+                    "operands_data1": [1],
+                }
+                self.dynamic_shape.max_input_shape = {
+                    "operands_data0": [4],
+                    "operands_data1": [4],
+                }
+                self.dynamic_shape.opt_input_shape = {
+                    "operands_data0": [2],
+                    "operands_data1": [2],
+                }
+            elif self.dims == 2:
+                self.dynamic_shape.min_input_shape = {
+                    "operands_data0": [1, 3],
+                    "operands_data1": [1],
+                }
+                self.dynamic_shape.max_input_shape = {
+                    "operands_data0": [4, 3],
+                    "operands_data1": [4],
+                }
+                self.dynamic_shape.opt_input_shape = {
+                    "operands_data0": [2, 3],
+                    "operands_data1": [3],
+                }
+            elif self.dims == 3:
+                self.dynamic_shape.min_input_shape = {
+                    "operands_data0": [1, 2, 3],
+                    "operands_data1": [1],
+                }
+                self.dynamic_shape.max_input_shape = {
+                    "operands_data0": [4, 2, 3],
+                    "operands_data1": [4],
+                }
+                self.dynamic_shape.opt_input_shape = {
+                    "operands_data0": [2, 2, 3],
+                    "operands_data1": [3],
+                }
+
+        def clear_dynamic_shape():
+            self.dynamic_shape.min_input_shape = {}
+            self.dynamic_shape.max_input_shape = {}
+            self.dynamic_shape.opt_input_shape = {}
+
+        def generate_trt_nodes_num(attrs, dynamic_shape):
+            if (not dynamic_shape) or ("..." in self.equation):
+                return 0, 4
+            return 1, 3
+
+        attrs = [
+            program_config.ops[i].attrs for i in range(len(program_config.ops))
+        ]
+
+        # for static_shape (generate_trt_nodes_num returns (0, 4) here)
+        clear_dynamic_shape()
+        self.trt_param.precision = paddle_infer.PrecisionType.Float32
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, False
+        ), 1e-5
+        self.trt_param.precision = paddle_infer.PrecisionType.Half
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, False
+        ), 1e-5
+
+        # for dynamic_shape (min/max/opt shapes are chosen from self.dims)
+        generate_dynamic_shape(attrs)
+        self.trt_param.precision = paddle_infer.PrecisionType.Float32
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, True
+        ), 1e-5
+        self.trt_param.precision = paddle_infer.PrecisionType.Half
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, True
+        ), 1e-5
+
+    def test(self):
+        self.run_test()
+
+
+class TrtConvertEinsumTest_DoubuleOperand_Matrix_Matrix(TrtLayerAutoScanTest):
+    def is_program_valid(self, program_config: ProgramConfig) -> bool:
+        ver = paddle_infer.get_trt_compile_version()
+        if ver[0] * 1000 + ver[1] * 100 + ver[2] * 10 < 8200:
+            return False
+        return True
+
+    def sample_program_configs(self):
+        self.trt_param.workspace_size = 1073741824  # 2**30
+
+        def generate_input_matrix(input_shape):
+            return np.ones(shape=input_shape).astype(np.float32)
+
+        # Double operands, tensor pairs: each item is [x_shape, y_shape, equation]
+        for item in [
+            [[4, 5], [4, 5], "ij,ij->ij"],  # MatrixEleMul
+            [[4, 5], [2, 5], "ij,kj->ik"],  # MatrixMul
+            [[4, 5], [3, 7], "ij,kl->ijkl"],  # MatrixOuter
+            [[3, 4, 5], [3, 5, 2], "bij,bjk->bik"],
+            [[3, 4, 5], [4, 5], "ijk,jk->i"],
+            [[3, 4, 5], [2, 5], "ijk,lk->ijl"],
+            [[2, 4, 5, 3], [3, 4, 5], "ijkl,lmn->ijkmn"],
+            [[3, 4, 5], [4, 5], "ijk,jk->ik"],
+            [[3, 4, 5], [4, 5], "ijk,jk->ij"],
+            [[4, 5], [4, 2, 5], "ik,ijk->j"],
+            [[4, 2, 5], [4, 5], "ijk,ik->jk"],
+            [[2, 4, 5, 3], [3, 2, 4], "ijkl,lmn->kmn"],
+            [[2, 4, 5, 3], [3, 2, 4], "ijkl,lmn->ijn"],
+            [[1, 3, 5], [1, 2, 3, 4], "blq,bhlk->bhlqk"],
+        ]:
+            self.x_shape = item[0]
+            self.y_shape = item[1]
+            equation = item[2]
+            self.equation = equation
+
+            dics = [{"equation": equation}, {}]
+            ops_config = [
+                {
+                    "op_type": "einsum",
+                    "op_inputs": {
+                        "Operands": ["operands_data0", "operands_data1"]
+                    },
+                    "op_outputs": {"Out": ["einsum_output_data"]},
+                    "op_attrs": dics[0],
+                }
+            ]
+            ops = self.generate_op_config(ops_config)
+
+            program_config = ProgramConfig(
+                ops=ops,
+                weights={},
+                inputs={
+                    "operands_data0": TensorConfig(
+                        data_gen=partial(generate_input_matrix, self.x_shape)
+                    ),
+                    "operands_data1": TensorConfig(
+                        data_gen=partial(generate_input_matrix, self.y_shape)
+                    ),
+                },
+                outputs=["einsum_output_data"],
+            )
+
+            yield program_config
+
+    def sample_predictor_configs(
+        self, program_config
+    ) -> (paddle_infer.Config, List[int], float):
+        def generate_dynamic_shape(attrs):
+            min_xshape = self.x_shape[:]
+            max_xshape = self.x_shape[:]
+            min_yshape = self.y_shape[:]
+            max_yshape = self.y_shape[:]
+            if "b" in self.equation:  # dim 0 of both inputs ranges over [1, 4]
+                min_xshape[0] = 1
+                max_xshape[0] = 4
+                min_yshape[0] = 1
+                max_yshape[0] = 4
+            self.dynamic_shape.min_input_shape = {
+                "operands_data0": min_xshape,
+                "operands_data1": min_yshape,
+            }
+            self.dynamic_shape.max_input_shape = {
+                "operands_data0": max_xshape,
+                "operands_data1": max_yshape,
+            }
+            self.dynamic_shape.opt_input_shape = {
+                "operands_data0": self.x_shape,
+                "operands_data1": self.y_shape,
+            }
+
+        def clear_dynamic_shape():
+            self.dynamic_shape.min_input_shape = {}
+            self.dynamic_shape.max_input_shape = {}
+            self.dynamic_shape.opt_input_shape = {}
+
+        def generate_trt_nodes_num(attrs, dynamic_shape):
+            if (not dynamic_shape) or ("..." in self.equation):
+                return 0, 4
+            return 1, 3
+
+        attrs = [
+            program_config.ops[i].attrs for i in range(len(program_config.ops))
+        ]
+
+        # for static_shape (generate_trt_nodes_num returns (0, 4) here)
+        clear_dynamic_shape()
+        self.trt_param.precision = paddle_infer.PrecisionType.Float32
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, False
+        ), 1e-5
+        self.trt_param.precision = paddle_infer.PrecisionType.Half
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, False
+        ), 1e-5
+
+        # for dynamic_shape (shapes derived from self.x_shape / self.y_shape)
+        generate_dynamic_shape(attrs)
+        self.trt_param.precision = paddle_infer.PrecisionType.Float32
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, True
+        ), 1e-5
+        self.trt_param.precision = paddle_infer.PrecisionType.Half
+        yield self.create_inference_config(), generate_trt_nodes_num(
+            attrs, True
+        ), 1e-5
+
+    def test(self):
+        self.run_test()
+
+
+if __name__ == "__main__":
+    unittest.main()