From b6bf8994bd176448929057628db05b226bc74ce3 Mon Sep 17 00:00:00 2001 From: ccrrong <101700995+ccrrong@users.noreply.github.com> Date: Thu, 23 Jun 2022 11:19:17 +0800 Subject: [PATCH] add cast trt converter (#43447) * add cast trt converter --- .../fluid/inference/api/analysis_predictor.cc | 143 ++++++++++++------ .../inference/tensorrt/convert/CMakeLists.txt | 1 + .../inference/tensorrt/convert/cast_op.cc | 69 +++++++++ paddle/fluid/inference/tensorrt/op_teller.cc | 65 +++++--- .../ir/inference/test_trt_convert_cast.py | 120 +++++++++++++++ 5 files changed, 331 insertions(+), 67 deletions(-) create mode 100644 paddle/fluid/inference/tensorrt/convert/cast_op.cc create mode 100644 python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_cast.py diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc index 381d242c890..c32edc3650a 100644 --- a/paddle/fluid/inference/api/analysis_predictor.cc +++ b/paddle/fluid/inference/api/analysis_predictor.cc @@ -104,7 +104,8 @@ bool IsPersistable(const framework::VarDesc *var) { } } // namespace -bool PaddleTensorToLoDTensor(const PaddleTensor &pt, framework::LoDTensor *t, +bool PaddleTensorToLoDTensor(const PaddleTensor &pt, + framework::LoDTensor *t, const platform::Place &place) { framework::DDim ddim = phi::make_ddim(pt.shape); void *input_ptr; @@ -132,18 +133,19 @@ bool PaddleTensorToLoDTensor(const PaddleTensor &pt, framework::LoDTensor *t, if (platform::is_cpu_place(place)) { // TODO(panyx0718): Init LoDTensor from existing memcpy to save a copy. - std::memcpy(static_cast(input_ptr), pt.data.data(), - pt.data.length()); + std::memcpy( + static_cast(input_ptr), pt.data.data(), pt.data.length()); } else if (platform::is_ipu_place(place)) { #ifdef PADDLE_WITH_IPU - std::memcpy(static_cast(input_ptr), pt.data.data(), - pt.data.length()); + std::memcpy( + static_cast(input_ptr), pt.data.data(), pt.data.length()); #else PADDLE_THROW(paddle::platform::errors::Fatal( "Not compile with WITH_IPU, should not reach here.")); #endif } else if (platform::is_gpu_place(place)) { - PADDLE_ENFORCE_EQ(platform::is_xpu_place(place), false, + PADDLE_ENFORCE_EQ(platform::is_xpu_place(place), + false, platform::errors::InvalidArgument( "Only one choice can be made between CPU and XPU.")); #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) @@ -151,8 +153,11 @@ bool PaddleTensorToLoDTensor(const PaddleTensor &pt, framework::LoDTensor *t, auto *dev_ctx = static_cast(pool.Get(place)); auto dst_gpu_place = place; - memory::Copy(dst_gpu_place, static_cast(input_ptr), - platform::CPUPlace(), pt.data.data(), pt.data.length(), + memory::Copy(dst_gpu_place, + static_cast(input_ptr), + platform::CPUPlace(), + pt.data.data(), + pt.data.length(), dev_ctx->stream()); #else PADDLE_THROW(paddle::platform::errors::Fatal( @@ -161,8 +166,11 @@ bool PaddleTensorToLoDTensor(const PaddleTensor &pt, framework::LoDTensor *t, } else if (platform::is_xpu_place(place)) { #ifdef PADDLE_WITH_XPU auto dst_xpu_place = place; - memory::Copy(dst_xpu_place, static_cast(input_ptr), - platform::CPUPlace(), pt.data.data(), pt.data.length()); + memory::Copy(dst_xpu_place, + static_cast(input_ptr), + platform::CPUPlace(), + pt.data.data(), + pt.data.length()); #else PADDLE_THROW(paddle::platform::errors::Fatal( "Not compile with XPU, should not reach here.")); @@ -245,7 +253,8 @@ bool AnalysisPredictor::Init( void AnalysisPredictor::InitPlace() { if (config_.use_gpu()) { - PADDLE_ENFORCE_EQ(config_.use_xpu(), false, + PADDLE_ENFORCE_EQ(config_.use_xpu(), + false, platform::errors::InvalidArgument( "Only one choice can be made between CPU and XPU.")); place_ = paddle::platform::CUDAPlace(config_.gpu_device_id()); @@ -502,7 +511,8 @@ static bool IsPrepareDataOptTargetOp(framework::OpDesc *op) { } static void DisablePrepareDataOpt( - std::shared_ptr inference_program, int block, + std::shared_ptr inference_program, + int block, bool pre_disable_opt) { bool disable_opt = false; auto &infer_block = inference_program->Block(block); @@ -512,8 +522,8 @@ static void DisablePrepareDataOpt( } if (op->HasAttr("sub_block")) { int blockID = op->GetBlockAttrId("sub_block"); - DisablePrepareDataOpt(inference_program, blockID, - disable_opt || pre_disable_opt); + DisablePrepareDataOpt( + inference_program, blockID, disable_opt || pre_disable_opt); } // disable prepare data if unfriendly op is found if (!disable_opt) { @@ -531,8 +541,8 @@ bool AnalysisPredictor::PrepareExecutor() { #endif DisablePrepareDataOpt(inference_program_, 0, false); - executor_->Prepare(sub_scope_, *inference_program_, 0, - config_.use_feed_fetch_ops_); + executor_->Prepare( + sub_scope_, *inference_program_, 0, config_.use_feed_fetch_ops_); PADDLE_ENFORCE_NOT_NULL(sub_scope_, platform::errors::PreconditionNotMet( @@ -578,8 +588,13 @@ bool AnalysisPredictor::PrepareFleetExecutor() { feed_fetch_vars.emplace_back(pair.second); } fleet_exe_->Init(config_.dist_config().carrier_id(), - *(inference_program_.get()), scope_.get(), place_, 1, - {task_node_.get()}, id_to_rank, feed_fetch_vars); + *(inference_program_.get()), + scope_.get(), + place_, + 1, + {task_node_.get()}, + id_to_rank, + feed_fetch_vars); return true; } @@ -616,8 +631,12 @@ bool AnalysisPredictor::CommInit() { peer_endpoints.emplace_back( config_.dist_config().trainer_endpoints()[rank]); } - InsertCommOp(var_name_base + std::to_string(order), ranks_in_group, - rank_in_group, peer_endpoints, comm_init_block, ring_id); + InsertCommOp(var_name_base + std::to_string(order), + ranks_in_group, + rank_in_group, + peer_endpoints, + comm_init_block, + ring_id); order += 1; } framework::NaiveExecutor e(place_); @@ -629,8 +648,11 @@ bool AnalysisPredictor::CommInit() { } void AnalysisPredictor::InsertCommOp( - std::string tmp_var_name, int nranks, int rank, - const std::vector &peer_endpoints, framework::BlockDesc *block, + std::string tmp_var_name, + int nranks, + int rank, + const std::vector &peer_endpoints, + framework::BlockDesc *block, int ring_id) { /* * tmp_var_name: the var name for var comm_id @@ -687,7 +709,8 @@ bool AnalysisPredictor::LoadConverterConfig( << config_.dist_config().comm_init_config() << "\n"; std::ifstream fin(config_.dist_config().comm_init_config(), std::ios::in); PADDLE_ENFORCE_EQ( - static_cast(fin.is_open()), true, + static_cast(fin.is_open()), + true, platform::errors::NotFound( "Cannot open file %s, please confirm whether the file is normal.", config_.dist_config().comm_init_config())); @@ -831,8 +854,9 @@ bool AnalysisPredictor::Run(const std::vector &inputs, timer.tic(); // set feed variable framework::Scope *scope = sub_scope_ ? sub_scope_ : scope_.get(); - PADDLE_ENFORCE_NOT_NULL(scope, platform::errors::PreconditionNotMet( - "The scope should not be nullptr.")); + PADDLE_ENFORCE_NOT_NULL( + scope, + platform::errors::PreconditionNotMet("The scope should not be nullptr.")); if (!SetFeed(inputs, scope)) { LOG(ERROR) << "fail to set feed"; return false; @@ -935,9 +959,11 @@ bool AnalysisPredictor::GetFetch(std::vector *outputs, for (size_t i = 0; i < fetches_.size(); ++i) { int idx = BOOST_GET_CONST(int, fetches_[i]->GetAttr("col")); PADDLE_ENFORCE_EQ( - static_cast(idx), i, + static_cast(idx), + i, platform::errors::InvalidArgument( - "Fetch op's col attr(%d) should be equal to the index(%d)", idx, + "Fetch op's col attr(%d) should be equal to the index(%d)", + idx, i)); framework::FetchType &fetch_var = framework::GetFetchVariable(*scope, "fetch", idx); @@ -978,7 +1004,8 @@ void AnalysisPredictor::PrepareArgument() { if (!config_.model_dir().empty()) { argument_.SetModelDir(config_.model_dir()); } else { - PADDLE_ENFORCE_EQ(config_.prog_file().empty(), false, + PADDLE_ENFORCE_EQ(config_.prog_file().empty(), + false, platform::errors::PreconditionNotMet( "Either model_dir or prog_file should be set.")); std::string dir = inference::analysis::GetDirRoot(config_.prog_file()); @@ -1123,7 +1150,8 @@ void AnalysisPredictor::OptimizeInferenceProgram() { Analyzer().Run(&argument_); PADDLE_ENFORCE_EQ( - argument_.scope_valid(), true, + argument_.scope_valid(), + true, platform::errors::InvalidArgument("The argument scope should be valid.")); VLOG(5) << "to prepare executor"; ARGUMENT_CHECK_FIELD((&argument_), ir_analyzed_program); @@ -1173,7 +1201,8 @@ CreatePaddlePredictor( } VLOG(3) << "create AnalysisConfig"; PADDLE_ENFORCE_EQ( - config.is_valid(), true, + config.is_valid(), + true, platform::errors::InvalidArgument( "Note: Each config can only be used for one predictor.")); @@ -1190,11 +1219,13 @@ CreatePaddlePredictor( std::call_once(gflags_initialized, [&]() { std::vector gflags; PADDLE_ENFORCE_GE( - config.memory_pool_init_size_mb(), 0.f, + config.memory_pool_init_size_mb(), + 0.f, platform::errors::InvalidArgument( "The size of memory pool should be greater than 0.")); PADDLE_ENFORCE_GE( - config.gpu_device_id(), 0, + config.gpu_device_id(), + 0, platform::errors::InvalidArgument( "Invalid device id (%d). The device id should be greater than 0.", config.gpu_device_id())); @@ -1303,8 +1334,9 @@ void AnalysisPredictor::PrepareFeedFetch() { } void AnalysisPredictor::CreateFeedFetchVar(framework::Scope *scope) { - PADDLE_ENFORCE_NOT_NULL(scope, platform::errors::InvalidArgument( - "The scope should not be nullptr.")); + PADDLE_ENFORCE_NOT_NULL( + scope, + platform::errors::InvalidArgument("The scope should not be nullptr.")); auto *var = scope->Var("feed"); var->GetMutable(); var = scope->Var("fetch"); @@ -1325,8 +1357,9 @@ AnalysisPredictor::GetInputTensorShape() { std::vector names = GetInputNames(); for (std::string name : names) { auto *var = inference_program_->Block(0).FindVar(name); - PADDLE_ENFORCE_NOT_NULL(var, platform::errors::PreconditionNotMet( - "Input %s does not exist.", name)); + PADDLE_ENFORCE_NOT_NULL( + var, + platform::errors::PreconditionNotMet("Input %s does not exist.", name)); input_shapes[name] = var->GetShape(); } return input_shapes; @@ -1565,7 +1598,8 @@ void AnalysisPredictor::StatisticShapeRangeInfo() { std::vector> counter; for (auto &it : m) counter.push_back(it); std::sort( - counter.begin(), counter.end(), + counter.begin(), + counter.end(), [](std::pair &a, std::pair &b) { return a.second > b.second; }); @@ -1587,8 +1621,8 @@ void AnalysisPredictor::StatisticShapeRangeInfo() { opt_shapes[name] = opt_shape; } - inference::SerializeShapeRangeInfo(config_.shape_range_info_path(), - min_shapes, max_shapes, opt_shapes); + inference::SerializeShapeRangeInfo( + config_.shape_range_info_path(), min_shapes, max_shapes, opt_shapes); } bool AnalysisPredictor::LoadProgramDesc() { @@ -1608,7 +1642,8 @@ bool AnalysisPredictor::LoadProgramDesc() { return false; } LOG(ERROR) << string::Sprintf( - "not valid model path '%s' or program path '%s'.", config_.model_dir(), + "not valid model path '%s' or program path '%s'.", + config_.model_dir(), config_.params_file()); return false; } @@ -1620,7 +1655,8 @@ bool AnalysisPredictor::LoadProgramDesc() { // Read binary std::ifstream fin(filename, std::ios::in | std::ios::binary); PADDLE_ENFORCE_EQ( - static_cast(fin.is_open()), true, + static_cast(fin.is_open()), + true, platform::errors::NotFound( "Cannot open file %s, please confirm whether the file is normal.", filename)); @@ -1722,7 +1758,8 @@ void AnalysisPredictor::ClearIntermediateTensor() { #if PADDLE_WITH_TENSORRT bool AnalysisPredictor::SaveTrtCalibToDisk() { - PADDLE_ENFORCE_EQ(config_.tensorrt_engine_enabled(), true, + PADDLE_ENFORCE_EQ(config_.tensorrt_engine_enabled(), + true, platform::errors::PreconditionNotMet( "This func can be invoked only in trt mode")); auto &block = inference_program_->Block(0); @@ -1963,6 +2000,7 @@ USE_TRT_CONVERTER(c_allreduce_sum) USE_TRT_CONVERTER(roll) USE_TRT_CONVERTER(strided_slice) USE_TRT_CONVERTER(transformer_input_convert) +USE_TRT_CONVERTER(cast) USE_TRT_CONVERTER(recover_padding) USE_TRT_CONVERTER(remove_padding) USE_TRT_CONVERTER(top_k) @@ -1990,8 +2028,10 @@ Predictor::Predictor(const Config &config) { << "Paddle2ONNX do't support convert the Model, fall back to using " "Paddle Inference."; } else { - predictor_ = paddle::CreatePaddlePredictor< - Config, paddle::PaddleEngineKind::kONNXRuntime>(config); + predictor_ = + paddle::CreatePaddlePredictor( + config); return; } #else @@ -2001,8 +2041,10 @@ Predictor::Predictor(const Config &config) { "fall back to using Paddle Inference."; #endif } - predictor_ = paddle::CreatePaddlePredictor< - Config, paddle::PaddleEngineKind::kAnalysis>(config); + predictor_ = + paddle::CreatePaddlePredictor( + config); } std::vector Predictor::GetInputNames() { @@ -2086,7 +2128,8 @@ std::shared_ptr CreatePredictor(const Config &config) { // NOLINT namespace services { PredictorPool::PredictorPool(const Config &config, size_t size) { PADDLE_ENFORCE_GE( - size, 1UL, + size, + 1UL, paddle::platform::errors::InvalidArgument( "The predictor pool size should be greater than 1, but it's (%d)", size)); @@ -2105,9 +2148,11 @@ PredictorPool::PredictorPool(const Config &config, size_t size) { Predictor *PredictorPool::Retrive(size_t idx) { PADDLE_ENFORCE_LT( - idx, preds_.size() + 1, + idx, + preds_.size() + 1, paddle::platform::errors::InvalidArgument( - "There are (%d) predictors in the pool, but the idx is (%d)", idx, + "There are (%d) predictors in the pool, but the idx is (%d)", + idx, preds_.size() + 1)); if (idx == 0) { return main_pred_.get(); diff --git a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt index 9795baf37fb..4e728dc74f7 100644 --- a/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt +++ b/paddle/fluid/inference/tensorrt/convert/CMakeLists.txt @@ -60,6 +60,7 @@ list( preln_skip_layernorm.cc roll_op.cc transformer_input_convert_op.cc + cast_op.cc remove_padding_op.cc recover_padding_op.cc preln_residual_bias.cc diff --git a/paddle/fluid/inference/tensorrt/convert/cast_op.cc b/paddle/fluid/inference/tensorrt/convert/cast_op.cc new file mode 100644 index 00000000000..18ea71fbf3b --- /dev/null +++ b/paddle/fluid/inference/tensorrt/convert/cast_op.cc @@ -0,0 +1,69 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/inference/tensorrt/convert/op_converter.h" + +namespace paddle { +namespace framework { +class Scope; + +namespace proto { +class OpDesc; +} // namespace proto +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace inference { +namespace tensorrt { + +class CastOpConverter : public OpConverter { + public: + void operator()(const framework::proto::OpDesc& op, + const framework::Scope& scope, + bool test_mode) override { + VLOG(3) << "convert a cast op to tensorrt"; + framework::OpDesc op_desc(op, nullptr); + + auto* input = engine_->GetITensor(op_desc.Input("X")[0]); + auto out_dtype = BOOST_GET_CONST(int, op_desc.GetAttr("out_dtype")); + + auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Identity, *input); + + switch (out_dtype) { + case 2: // INT32 = 2 + layer->getOutput(0)->setType(nvinfer1::DataType::kINT32); + break; + case 4: // FP16 = 4 + layer->getOutput(0)->setType(nvinfer1::DataType::kHALF); + break; + case 5: // FP32 = 5 + layer->getOutput(0)->setType(nvinfer1::DataType::kFLOAT); + break; + default: + LOG(ERROR) << "Unable to convert a fluid data type(" << out_dtype + << ") to a nvinfer DataType"; + break; + } + + auto output_name = op_desc.Output("Out")[0]; + RreplenishLayerAndOutput(layer, "cast", {output_name}, test_mode); + } +}; + +} // namespace tensorrt +} // namespace inference +} // namespace paddle + +REGISTER_TRT_OP_CONVERTER(cast, CastOpConverter); diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc index f9086d7a822..66dcac02b45 100644 --- a/paddle/fluid/inference/tensorrt/op_teller.cc +++ b/paddle/fluid/inference/tensorrt/op_teller.cc @@ -55,7 +55,8 @@ struct SimpleOpTypeSetTeller : public Teller { #endif } - bool operator()(const std::string& op_type, const framework::OpDesc& desc, + bool operator()(const std::string& op_type, + const framework::OpDesc& desc, bool use_no_calib_int8) override { if (use_no_calib_int8) { return int8_teller_set.count(op_type); @@ -162,6 +163,7 @@ struct SimpleOpTypeSetTeller : public Teller { "c_allreduce_max", "c_allreduce_prod", "roll", + "cast", "preln_skip_layernorm", "transformer_input_convert", "recover_padding", @@ -265,6 +267,7 @@ struct SimpleOpTypeSetTeller : public Teller { "c_allreduce_max", "c_allreduce_prod", "roll", + "cast", "multiclass_nms3", "transformer_input_convert", "recover_padding", @@ -273,7 +276,8 @@ struct SimpleOpTypeSetTeller : public Teller { "unsqueeze2"}; }; -bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, +bool OpTeller::Tell(const framework::ir::Node* node, + bool use_no_calib_int8, bool with_dynamic_shape) { const std::string op_type = node->Op()->Type(); const framework::OpDesc desc = *node->Op(); @@ -818,8 +822,8 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, } if (op_type == "nearest_interp") { - std::vector attrs{"interp_method", "align_corners", "scale", - "out_h", "out_w"}; + std::vector attrs{ + "interp_method", "align_corners", "scale", "out_h", "out_w"}; for (auto const attr : attrs) { if (!desc.HasAttr(attr)) return false; } @@ -859,9 +863,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, } if (op_type == "nearest_interp_v2") { - std::vector attrs{"data_layout", "interp_method", - "align_corners", "scale", - "out_h", "out_w"}; + std::vector attrs{"data_layout", + "interp_method", + "align_corners", + "scale", + "out_h", + "out_w"}; for (auto const attr : attrs) { if (!desc.HasAttr(attr)) return false; } @@ -887,9 +894,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, } if (op_type == "bilinear_interp_v2") { - std::vector attrs{"data_layout", "interp_method", - "align_corners", "scale", - "out_h", "out_w"}; + std::vector attrs{"data_layout", + "interp_method", + "align_corners", + "scale", + "out_h", + "out_w"}; for (auto const attr : attrs) { if (!desc.HasAttr(attr)) { VLOG(3) << "The op_type " << op_type << " doesn't have the attr " @@ -1032,8 +1042,8 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, } if (op_type == "batch_norm") { - const std::vector bn_inputs = {"X", "Bias", "Mean", "Scale", - "Variance"}; + const std::vector bn_inputs = { + "X", "Bias", "Mean", "Scale", "Variance"}; for (unsigned int i = 0; i < bn_inputs.size(); i++) { if (desc.Input(bn_inputs[i]).size() != 1) { VLOG(3) << "Invalid " << bn_inputs[i] @@ -1585,8 +1595,10 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, "the roi_align will change the batch size."; return false; } - std::vector attrs{"pooled_height", "pooled_width", - "spatial_scale", "sampling_ratio", + std::vector attrs{"pooled_height", + "pooled_width", + "spatial_scale", + "sampling_ratio", "aligned"}; for (auto const attr : attrs) { if (!desc.HasAttr(attr)) return false; @@ -1771,10 +1783,10 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, auto x_var_name = desc.Input("X")[0]; auto* x_var_desc = block->FindVar(x_var_name); const auto x_shape = x_var_desc->GetShape(); - int input_num = std::accumulate(x_shape.begin() + 1, x_shape.end(), 1, - std::multiplies()); - int shape_num = std::accumulate(shape.begin() + 1, shape.end(), 1, - std::multiplies()); + int input_num = std::accumulate( + x_shape.begin() + 1, x_shape.end(), 1, std::multiplies()); + int shape_num = std::accumulate( + shape.begin() + 1, shape.end(), 1, std::multiplies()); if (input_num == shape_num) { return true; } @@ -1960,6 +1972,23 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8, } } + if (op_type == "cast") { + int in_dtype = BOOST_GET_CONST(int, desc.GetAttr("in_dtype")); + int out_dtype = BOOST_GET_CONST(int, desc.GetAttr("out_dtype")); + if ((in_dtype == 4 || in_dtype == 5) && out_dtype == 4) { + VLOG(3) << "unsupport data type conversion"; + return false; + } + if (!((in_dtype == 5 || in_dtype == 4 || in_dtype == 2 || + in_dtype == 0) && + (out_dtype == 5 || out_dtype == 4 || out_dtype == 2))) { + VLOG(3) + << "only valid conversions are: " + "(kFLOAT | kHALF | kINT32 | kBOOL) -> (kFLOAT | kHALF | kINT32)"; + return false; + } + } + if (op_type == "top_k_v2" || op_type == "top_k") { auto* block = desc.Block(); auto x_var_name = desc.Input("X")[0]; diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_cast.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_cast.py new file mode 100644 index 00000000000..c381dbc2d6a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_cast.py @@ -0,0 +1,120 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from trt_layer_auto_scan_test import TrtLayerAutoScanTest, SkipReasons +from program_config import TensorConfig, ProgramConfig +import unittest +import numpy as np +import paddle.inference as paddle_infer +from functools import partial +from typing import Optional, List, Callable, Dict, Any, Set + + +class TrtConvertCastTest(TrtLayerAutoScanTest): + + def is_program_valid(self, program_config: ProgramConfig) -> bool: + attrs = [ + program_config.ops[i].attrs for i in range(len(program_config.ops)) + ] + if attrs[0]['in_dtype'] in [4, 5] and attrs[0]['out_dtype'] == 4: + return False + if attrs[0]['in_dtype'] not in [ + 0, 2, 4, 5 + ] or attrs[0]['out_dtype'] not in [2, 4, 5]: + return False + return True + + def sample_program_configs(self): + + def generate_input(type): + if type == 0: + return np.ones([1, 3, 64, 64]).astype(np.bool) + elif type == 2: + return np.ones([1, 3, 64, 64]).astype(np.int32) + elif type == 4: + return np.ones([1, 3, 64, 64]).astype(np.float16) + else: + return np.ones([1, 3, 64, 64]).astype(np.float32) + + for in_dtype in [0, 2, 4, 5, 6]: + for out_dtype in [0, 2, 4, 5, 6]: + dics = [{"in_dtype": in_dtype, "out_dtype": out_dtype}] + + ops_config = [{ + "op_type": "cast", + "op_inputs": { + "X": ["input_data"] + }, + "op_outputs": { + "Out": ["cast_output_data"] + }, + "op_attrs": dics[0] + }] + ops = self.generate_op_config(ops_config) + + program_config = ProgramConfig( + ops=ops, + weights={}, + inputs={ + "input_data": + TensorConfig(data_gen=partial(generate_input, in_dtype)) + }, + outputs=["cast_output_data"]) + + yield program_config + + def sample_predictor_configs( + self, program_config) -> (paddle_infer.Config, List[int], float): + + def generate_dynamic_shape(attrs): + self.dynamic_shape.min_input_shape = {"input_data": [1, 3, 64, 64]} + self.dynamic_shape.max_input_shape = {"input_data": [4, 3, 64, 64]} + self.dynamic_shape.opt_input_shape = {"input_data": [1, 3, 64, 64]} + + def clear_dynamic_shape(): + self.dynamic_shape.min_input_shape = {} + self.dynamic_shape.max_input_shape = {} + self.dynamic_shape.opt_input_shape = {} + + def generate_trt_nodes_num(attrs, dynamic_shape): + return 1, 2 + + attrs = [ + program_config.ops[i].attrs for i in range(len(program_config.ops)) + ] + + # for static_shape + clear_dynamic_shape() + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False), 1e-5 + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, False), 1e-2 + + # for dynamic_shape + generate_dynamic_shape(attrs) + self.trt_param.precision = paddle_infer.PrecisionType.Float32 + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, True), 1e-5 + self.trt_param.precision = paddle_infer.PrecisionType.Half + yield self.create_inference_config(), generate_trt_nodes_num( + attrs, True), 1e-2 + + def test(self): + self.run_test() + + +if __name__ == "__main__": + unittest.main() -- GitLab