未验证 提交 b6bf8994 编写于 作者: C ccrrong 提交者: GitHub

add cast trt converter (#43447)

* add cast trt converter
上级 8902a414
......@@ -104,7 +104,8 @@ bool IsPersistable(const framework::VarDesc *var) {
}
} // namespace
bool PaddleTensorToLoDTensor(const PaddleTensor &pt, framework::LoDTensor *t,
bool PaddleTensorToLoDTensor(const PaddleTensor &pt,
framework::LoDTensor *t,
const platform::Place &place) {
framework::DDim ddim = phi::make_ddim(pt.shape);
void *input_ptr;
......@@ -132,18 +133,19 @@ bool PaddleTensorToLoDTensor(const PaddleTensor &pt, framework::LoDTensor *t,
if (platform::is_cpu_place(place)) {
// TODO(panyx0718): Init LoDTensor from existing memcpy to save a copy.
std::memcpy(static_cast<void *>(input_ptr), pt.data.data(),
pt.data.length());
std::memcpy(
static_cast<void *>(input_ptr), pt.data.data(), pt.data.length());
} else if (platform::is_ipu_place(place)) {
#ifdef PADDLE_WITH_IPU
std::memcpy(static_cast<void *>(input_ptr), pt.data.data(),
pt.data.length());
std::memcpy(
static_cast<void *>(input_ptr), pt.data.data(), pt.data.length());
#else
PADDLE_THROW(paddle::platform::errors::Fatal(
"Not compile with WITH_IPU, should not reach here."));
#endif
} else if (platform::is_gpu_place(place)) {
PADDLE_ENFORCE_EQ(platform::is_xpu_place(place), false,
PADDLE_ENFORCE_EQ(platform::is_xpu_place(place),
false,
platform::errors::InvalidArgument(
"Only one choice can be made between CPU and XPU."));
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
......@@ -151,8 +153,11 @@ bool PaddleTensorToLoDTensor(const PaddleTensor &pt, framework::LoDTensor *t,
auto *dev_ctx =
static_cast<const platform::CUDADeviceContext *>(pool.Get(place));
auto dst_gpu_place = place;
memory::Copy(dst_gpu_place, static_cast<void *>(input_ptr),
platform::CPUPlace(), pt.data.data(), pt.data.length(),
memory::Copy(dst_gpu_place,
static_cast<void *>(input_ptr),
platform::CPUPlace(),
pt.data.data(),
pt.data.length(),
dev_ctx->stream());
#else
PADDLE_THROW(paddle::platform::errors::Fatal(
......@@ -161,8 +166,11 @@ bool PaddleTensorToLoDTensor(const PaddleTensor &pt, framework::LoDTensor *t,
} else if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU
auto dst_xpu_place = place;
memory::Copy(dst_xpu_place, static_cast<void *>(input_ptr),
platform::CPUPlace(), pt.data.data(), pt.data.length());
memory::Copy(dst_xpu_place,
static_cast<void *>(input_ptr),
platform::CPUPlace(),
pt.data.data(),
pt.data.length());
#else
PADDLE_THROW(paddle::platform::errors::Fatal(
"Not compile with XPU, should not reach here."));
......@@ -245,7 +253,8 @@ bool AnalysisPredictor::Init(
void AnalysisPredictor::InitPlace() {
if (config_.use_gpu()) {
PADDLE_ENFORCE_EQ(config_.use_xpu(), false,
PADDLE_ENFORCE_EQ(config_.use_xpu(),
false,
platform::errors::InvalidArgument(
"Only one choice can be made between CPU and XPU."));
place_ = paddle::platform::CUDAPlace(config_.gpu_device_id());
......@@ -502,7 +511,8 @@ static bool IsPrepareDataOptTargetOp(framework::OpDesc *op) {
}
static void DisablePrepareDataOpt(
std::shared_ptr<framework::ProgramDesc> inference_program, int block,
std::shared_ptr<framework::ProgramDesc> inference_program,
int block,
bool pre_disable_opt) {
bool disable_opt = false;
auto &infer_block = inference_program->Block(block);
......@@ -512,8 +522,8 @@ static void DisablePrepareDataOpt(
}
if (op->HasAttr("sub_block")) {
int blockID = op->GetBlockAttrId("sub_block");
DisablePrepareDataOpt(inference_program, blockID,
disable_opt || pre_disable_opt);
DisablePrepareDataOpt(
inference_program, blockID, disable_opt || pre_disable_opt);
}
// disable prepare data if unfriendly op is found
if (!disable_opt) {
......@@ -531,8 +541,8 @@ bool AnalysisPredictor::PrepareExecutor() {
#endif
DisablePrepareDataOpt(inference_program_, 0, false);
executor_->Prepare(sub_scope_, *inference_program_, 0,
config_.use_feed_fetch_ops_);
executor_->Prepare(
sub_scope_, *inference_program_, 0, config_.use_feed_fetch_ops_);
PADDLE_ENFORCE_NOT_NULL(sub_scope_,
platform::errors::PreconditionNotMet(
......@@ -578,8 +588,13 @@ bool AnalysisPredictor::PrepareFleetExecutor() {
feed_fetch_vars.emplace_back(pair.second);
}
fleet_exe_->Init(config_.dist_config().carrier_id(),
*(inference_program_.get()), scope_.get(), place_, 1,
{task_node_.get()}, id_to_rank, feed_fetch_vars);
*(inference_program_.get()),
scope_.get(),
place_,
1,
{task_node_.get()},
id_to_rank,
feed_fetch_vars);
return true;
}
......@@ -616,8 +631,12 @@ bool AnalysisPredictor::CommInit() {
peer_endpoints.emplace_back(
config_.dist_config().trainer_endpoints()[rank]);
}
InsertCommOp(var_name_base + std::to_string(order), ranks_in_group,
rank_in_group, peer_endpoints, comm_init_block, ring_id);
InsertCommOp(var_name_base + std::to_string(order),
ranks_in_group,
rank_in_group,
peer_endpoints,
comm_init_block,
ring_id);
order += 1;
}
framework::NaiveExecutor e(place_);
......@@ -629,8 +648,11 @@ bool AnalysisPredictor::CommInit() {
}
void AnalysisPredictor::InsertCommOp(
std::string tmp_var_name, int nranks, int rank,
const std::vector<std::string> &peer_endpoints, framework::BlockDesc *block,
std::string tmp_var_name,
int nranks,
int rank,
const std::vector<std::string> &peer_endpoints,
framework::BlockDesc *block,
int ring_id) {
/*
* tmp_var_name: the var name for var comm_id
......@@ -687,7 +709,8 @@ bool AnalysisPredictor::LoadConverterConfig(
<< config_.dist_config().comm_init_config() << "\n";
std::ifstream fin(config_.dist_config().comm_init_config(), std::ios::in);
PADDLE_ENFORCE_EQ(
static_cast<bool>(fin.is_open()), true,
static_cast<bool>(fin.is_open()),
true,
platform::errors::NotFound(
"Cannot open file %s, please confirm whether the file is normal.",
config_.dist_config().comm_init_config()));
......@@ -831,8 +854,9 @@ bool AnalysisPredictor::Run(const std::vector<PaddleTensor> &inputs,
timer.tic();
// set feed variable
framework::Scope *scope = sub_scope_ ? sub_scope_ : scope_.get();
PADDLE_ENFORCE_NOT_NULL(scope, platform::errors::PreconditionNotMet(
"The scope should not be nullptr."));
PADDLE_ENFORCE_NOT_NULL(
scope,
platform::errors::PreconditionNotMet("The scope should not be nullptr."));
if (!SetFeed(inputs, scope)) {
LOG(ERROR) << "fail to set feed";
return false;
......@@ -935,9 +959,11 @@ bool AnalysisPredictor::GetFetch(std::vector<PaddleTensor> *outputs,
for (size_t i = 0; i < fetches_.size(); ++i) {
int idx = BOOST_GET_CONST(int, fetches_[i]->GetAttr("col"));
PADDLE_ENFORCE_EQ(
static_cast<size_t>(idx), i,
static_cast<size_t>(idx),
i,
platform::errors::InvalidArgument(
"Fetch op's col attr(%d) should be equal to the index(%d)", idx,
"Fetch op's col attr(%d) should be equal to the index(%d)",
idx,
i));
framework::FetchType &fetch_var =
framework::GetFetchVariable(*scope, "fetch", idx);
......@@ -978,7 +1004,8 @@ void AnalysisPredictor::PrepareArgument() {
if (!config_.model_dir().empty()) {
argument_.SetModelDir(config_.model_dir());
} else {
PADDLE_ENFORCE_EQ(config_.prog_file().empty(), false,
PADDLE_ENFORCE_EQ(config_.prog_file().empty(),
false,
platform::errors::PreconditionNotMet(
"Either model_dir or prog_file should be set."));
std::string dir = inference::analysis::GetDirRoot(config_.prog_file());
......@@ -1123,7 +1150,8 @@ void AnalysisPredictor::OptimizeInferenceProgram() {
Analyzer().Run(&argument_);
PADDLE_ENFORCE_EQ(
argument_.scope_valid(), true,
argument_.scope_valid(),
true,
platform::errors::InvalidArgument("The argument scope should be valid."));
VLOG(5) << "to prepare executor";
ARGUMENT_CHECK_FIELD((&argument_), ir_analyzed_program);
......@@ -1173,7 +1201,8 @@ CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kAnalysis>(
}
VLOG(3) << "create AnalysisConfig";
PADDLE_ENFORCE_EQ(
config.is_valid(), true,
config.is_valid(),
true,
platform::errors::InvalidArgument(
"Note: Each config can only be used for one predictor."));
......@@ -1190,11 +1219,13 @@ CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kAnalysis>(
std::call_once(gflags_initialized, [&]() {
std::vector<std::string> gflags;
PADDLE_ENFORCE_GE(
config.memory_pool_init_size_mb(), 0.f,
config.memory_pool_init_size_mb(),
0.f,
platform::errors::InvalidArgument(
"The size of memory pool should be greater than 0."));
PADDLE_ENFORCE_GE(
config.gpu_device_id(), 0,
config.gpu_device_id(),
0,
platform::errors::InvalidArgument(
"Invalid device id (%d). The device id should be greater than 0.",
config.gpu_device_id()));
......@@ -1303,8 +1334,9 @@ void AnalysisPredictor::PrepareFeedFetch() {
}
void AnalysisPredictor::CreateFeedFetchVar(framework::Scope *scope) {
PADDLE_ENFORCE_NOT_NULL(scope, platform::errors::InvalidArgument(
"The scope should not be nullptr."));
PADDLE_ENFORCE_NOT_NULL(
scope,
platform::errors::InvalidArgument("The scope should not be nullptr."));
auto *var = scope->Var("feed");
var->GetMutable<framework::FeedList>();
var = scope->Var("fetch");
......@@ -1325,8 +1357,9 @@ AnalysisPredictor::GetInputTensorShape() {
std::vector<std::string> names = GetInputNames();
for (std::string name : names) {
auto *var = inference_program_->Block(0).FindVar(name);
PADDLE_ENFORCE_NOT_NULL(var, platform::errors::PreconditionNotMet(
"Input %s does not exist.", name));
PADDLE_ENFORCE_NOT_NULL(
var,
platform::errors::PreconditionNotMet("Input %s does not exist.", name));
input_shapes[name] = var->GetShape();
}
return input_shapes;
......@@ -1565,7 +1598,8 @@ void AnalysisPredictor::StatisticShapeRangeInfo() {
std::vector<std::pair<int32_t, int32_t>> counter;
for (auto &it : m) counter.push_back(it);
std::sort(
counter.begin(), counter.end(),
counter.begin(),
counter.end(),
[](std::pair<int32_t, int32_t> &a, std::pair<int32_t, int32_t> &b) {
return a.second > b.second;
});
......@@ -1587,8 +1621,8 @@ void AnalysisPredictor::StatisticShapeRangeInfo() {
opt_shapes[name] = opt_shape;
}
inference::SerializeShapeRangeInfo(config_.shape_range_info_path(),
min_shapes, max_shapes, opt_shapes);
inference::SerializeShapeRangeInfo(
config_.shape_range_info_path(), min_shapes, max_shapes, opt_shapes);
}
bool AnalysisPredictor::LoadProgramDesc() {
......@@ -1608,7 +1642,8 @@ bool AnalysisPredictor::LoadProgramDesc() {
return false;
}
LOG(ERROR) << string::Sprintf(
"not valid model path '%s' or program path '%s'.", config_.model_dir(),
"not valid model path '%s' or program path '%s'.",
config_.model_dir(),
config_.params_file());
return false;
}
......@@ -1620,7 +1655,8 @@ bool AnalysisPredictor::LoadProgramDesc() {
// Read binary
std::ifstream fin(filename, std::ios::in | std::ios::binary);
PADDLE_ENFORCE_EQ(
static_cast<bool>(fin.is_open()), true,
static_cast<bool>(fin.is_open()),
true,
platform::errors::NotFound(
"Cannot open file %s, please confirm whether the file is normal.",
filename));
......@@ -1722,7 +1758,8 @@ void AnalysisPredictor::ClearIntermediateTensor() {
#if PADDLE_WITH_TENSORRT
bool AnalysisPredictor::SaveTrtCalibToDisk() {
PADDLE_ENFORCE_EQ(config_.tensorrt_engine_enabled(), true,
PADDLE_ENFORCE_EQ(config_.tensorrt_engine_enabled(),
true,
platform::errors::PreconditionNotMet(
"This func can be invoked only in trt mode"));
auto &block = inference_program_->Block(0);
......@@ -1963,6 +2000,7 @@ USE_TRT_CONVERTER(c_allreduce_sum)
USE_TRT_CONVERTER(roll)
USE_TRT_CONVERTER(strided_slice)
USE_TRT_CONVERTER(transformer_input_convert)
USE_TRT_CONVERTER(cast)
USE_TRT_CONVERTER(recover_padding)
USE_TRT_CONVERTER(remove_padding)
USE_TRT_CONVERTER(top_k)
......@@ -1990,8 +2028,10 @@ Predictor::Predictor(const Config &config) {
<< "Paddle2ONNX do't support convert the Model, fall back to using "
"Paddle Inference.";
} else {
predictor_ = paddle::CreatePaddlePredictor<
Config, paddle::PaddleEngineKind::kONNXRuntime>(config);
predictor_ =
paddle::CreatePaddlePredictor<Config,
paddle::PaddleEngineKind::kONNXRuntime>(
config);
return;
}
#else
......@@ -2001,8 +2041,10 @@ Predictor::Predictor(const Config &config) {
"fall back to using Paddle Inference.";
#endif
}
predictor_ = paddle::CreatePaddlePredictor<
Config, paddle::PaddleEngineKind::kAnalysis>(config);
predictor_ =
paddle::CreatePaddlePredictor<Config,
paddle::PaddleEngineKind::kAnalysis>(
config);
}
std::vector<std::string> Predictor::GetInputNames() {
......@@ -2086,7 +2128,8 @@ std::shared_ptr<Predictor> CreatePredictor(const Config &config) { // NOLINT
namespace services {
PredictorPool::PredictorPool(const Config &config, size_t size) {
PADDLE_ENFORCE_GE(
size, 1UL,
size,
1UL,
paddle::platform::errors::InvalidArgument(
"The predictor pool size should be greater than 1, but it's (%d)",
size));
......@@ -2105,9 +2148,11 @@ PredictorPool::PredictorPool(const Config &config, size_t size) {
Predictor *PredictorPool::Retrive(size_t idx) {
PADDLE_ENFORCE_LT(
idx, preds_.size() + 1,
idx,
preds_.size() + 1,
paddle::platform::errors::InvalidArgument(
"There are (%d) predictors in the pool, but the idx is (%d)", idx,
"There are (%d) predictors in the pool, but the idx is (%d)",
idx,
preds_.size() + 1));
if (idx == 0) {
return main_pred_.get();
......
......@@ -60,6 +60,7 @@ list(
preln_skip_layernorm.cc
roll_op.cc
transformer_input_convert_op.cc
cast_op.cc
remove_padding_op.cc
recover_padding_op.cc
preln_residual_bias.cc
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
namespace paddle {
namespace framework {
class Scope;
namespace proto {
class OpDesc;
} // namespace proto
} // namespace framework
} // namespace paddle
namespace paddle {
namespace inference {
namespace tensorrt {
class CastOpConverter : public OpConverter {
public:
void operator()(const framework::proto::OpDesc& op,
const framework::Scope& scope,
bool test_mode) override {
VLOG(3) << "convert a cast op to tensorrt";
framework::OpDesc op_desc(op, nullptr);
auto* input = engine_->GetITensor(op_desc.Input("X")[0]);
auto out_dtype = BOOST_GET_CONST(int, op_desc.GetAttr("out_dtype"));
auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Identity, *input);
switch (out_dtype) {
case 2: // INT32 = 2
layer->getOutput(0)->setType(nvinfer1::DataType::kINT32);
break;
case 4: // FP16 = 4
layer->getOutput(0)->setType(nvinfer1::DataType::kHALF);
break;
case 5: // FP32 = 5
layer->getOutput(0)->setType(nvinfer1::DataType::kFLOAT);
break;
default:
LOG(ERROR) << "Unable to convert a fluid data type(" << out_dtype
<< ") to a nvinfer DataType";
break;
}
auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer, "cast", {output_name}, test_mode);
}
};
} // namespace tensorrt
} // namespace inference
} // namespace paddle
REGISTER_TRT_OP_CONVERTER(cast, CastOpConverter);
......@@ -55,7 +55,8 @@ struct SimpleOpTypeSetTeller : public Teller {
#endif
}
bool operator()(const std::string& op_type, const framework::OpDesc& desc,
bool operator()(const std::string& op_type,
const framework::OpDesc& desc,
bool use_no_calib_int8) override {
if (use_no_calib_int8) {
return int8_teller_set.count(op_type);
......@@ -162,6 +163,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"c_allreduce_max",
"c_allreduce_prod",
"roll",
"cast",
"preln_skip_layernorm",
"transformer_input_convert",
"recover_padding",
......@@ -265,6 +267,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"c_allreduce_max",
"c_allreduce_prod",
"roll",
"cast",
"multiclass_nms3",
"transformer_input_convert",
"recover_padding",
......@@ -273,7 +276,8 @@ struct SimpleOpTypeSetTeller : public Teller {
"unsqueeze2"};
};
bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
bool OpTeller::Tell(const framework::ir::Node* node,
bool use_no_calib_int8,
bool with_dynamic_shape) {
const std::string op_type = node->Op()->Type();
const framework::OpDesc desc = *node->Op();
......@@ -818,8 +822,8 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
}
if (op_type == "nearest_interp") {
std::vector<std::string> attrs{"interp_method", "align_corners", "scale",
"out_h", "out_w"};
std::vector<std::string> attrs{
"interp_method", "align_corners", "scale", "out_h", "out_w"};
for (auto const attr : attrs) {
if (!desc.HasAttr(attr)) return false;
}
......@@ -859,9 +863,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
}
if (op_type == "nearest_interp_v2") {
std::vector<std::string> attrs{"data_layout", "interp_method",
"align_corners", "scale",
"out_h", "out_w"};
std::vector<std::string> attrs{"data_layout",
"interp_method",
"align_corners",
"scale",
"out_h",
"out_w"};
for (auto const attr : attrs) {
if (!desc.HasAttr(attr)) return false;
}
......@@ -887,9 +894,12 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
}
if (op_type == "bilinear_interp_v2") {
std::vector<std::string> attrs{"data_layout", "interp_method",
"align_corners", "scale",
"out_h", "out_w"};
std::vector<std::string> attrs{"data_layout",
"interp_method",
"align_corners",
"scale",
"out_h",
"out_w"};
for (auto const attr : attrs) {
if (!desc.HasAttr(attr)) {
VLOG(3) << "The op_type " << op_type << " doesn't have the attr "
......@@ -1032,8 +1042,8 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
}
if (op_type == "batch_norm") {
const std::vector<std::string> bn_inputs = {"X", "Bias", "Mean", "Scale",
"Variance"};
const std::vector<std::string> bn_inputs = {
"X", "Bias", "Mean", "Scale", "Variance"};
for (unsigned int i = 0; i < bn_inputs.size(); i++) {
if (desc.Input(bn_inputs[i]).size() != 1) {
VLOG(3) << "Invalid " << bn_inputs[i]
......@@ -1585,8 +1595,10 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
"the roi_align will change the batch size.";
return false;
}
std::vector<std::string> attrs{"pooled_height", "pooled_width",
"spatial_scale", "sampling_ratio",
std::vector<std::string> attrs{"pooled_height",
"pooled_width",
"spatial_scale",
"sampling_ratio",
"aligned"};
for (auto const attr : attrs) {
if (!desc.HasAttr(attr)) return false;
......@@ -1771,10 +1783,10 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
auto x_var_name = desc.Input("X")[0];
auto* x_var_desc = block->FindVar(x_var_name);
const auto x_shape = x_var_desc->GetShape();
int input_num = std::accumulate(x_shape.begin() + 1, x_shape.end(), 1,
std::multiplies<int>());
int shape_num = std::accumulate(shape.begin() + 1, shape.end(), 1,
std::multiplies<int>());
int input_num = std::accumulate(
x_shape.begin() + 1, x_shape.end(), 1, std::multiplies<int>());
int shape_num = std::accumulate(
shape.begin() + 1, shape.end(), 1, std::multiplies<int>());
if (input_num == shape_num) {
return true;
}
......@@ -1960,6 +1972,23 @@ bool OpTeller::Tell(const framework::ir::Node* node, bool use_no_calib_int8,
}
}
if (op_type == "cast") {
int in_dtype = BOOST_GET_CONST(int, desc.GetAttr("in_dtype"));
int out_dtype = BOOST_GET_CONST(int, desc.GetAttr("out_dtype"));
if ((in_dtype == 4 || in_dtype == 5) && out_dtype == 4) {
VLOG(3) << "unsupport data type conversion";
return false;
}
if (!((in_dtype == 5 || in_dtype == 4 || in_dtype == 2 ||
in_dtype == 0) &&
(out_dtype == 5 || out_dtype == 4 || out_dtype == 2))) {
VLOG(3)
<< "only valid conversions are: "
"(kFLOAT | kHALF | kINT32 | kBOOL) -> (kFLOAT | kHALF | kINT32)";
return false;
}
}
if (op_type == "top_k_v2" || op_type == "top_k") {
auto* block = desc.Block();
auto x_var_name = desc.Input("X")[0];
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from trt_layer_auto_scan_test import TrtLayerAutoScanTest, SkipReasons
from program_config import TensorConfig, ProgramConfig
import unittest
import numpy as np
import paddle.inference as paddle_infer
from functools import partial
from typing import Optional, List, Callable, Dict, Any, Set
class TrtConvertCastTest(TrtLayerAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
attrs = [
program_config.ops[i].attrs for i in range(len(program_config.ops))
]
if attrs[0]['in_dtype'] in [4, 5] and attrs[0]['out_dtype'] == 4:
return False
if attrs[0]['in_dtype'] not in [
0, 2, 4, 5
] or attrs[0]['out_dtype'] not in [2, 4, 5]:
return False
return True
def sample_program_configs(self):
def generate_input(type):
if type == 0:
return np.ones([1, 3, 64, 64]).astype(np.bool)
elif type == 2:
return np.ones([1, 3, 64, 64]).astype(np.int32)
elif type == 4:
return np.ones([1, 3, 64, 64]).astype(np.float16)
else:
return np.ones([1, 3, 64, 64]).astype(np.float32)
for in_dtype in [0, 2, 4, 5, 6]:
for out_dtype in [0, 2, 4, 5, 6]:
dics = [{"in_dtype": in_dtype, "out_dtype": out_dtype}]
ops_config = [{
"op_type": "cast",
"op_inputs": {
"X": ["input_data"]
},
"op_outputs": {
"Out": ["cast_output_data"]
},
"op_attrs": dics[0]
}]
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"input_data":
TensorConfig(data_gen=partial(generate_input, in_dtype))
},
outputs=["cast_output_data"])
yield program_config
def sample_predictor_configs(
self, program_config) -> (paddle_infer.Config, List[int], float):
def generate_dynamic_shape(attrs):
self.dynamic_shape.min_input_shape = {"input_data": [1, 3, 64, 64]}
self.dynamic_shape.max_input_shape = {"input_data": [4, 3, 64, 64]}
self.dynamic_shape.opt_input_shape = {"input_data": [1, 3, 64, 64]}
def clear_dynamic_shape():
self.dynamic_shape.min_input_shape = {}
self.dynamic_shape.max_input_shape = {}
self.dynamic_shape.opt_input_shape = {}
def generate_trt_nodes_num(attrs, dynamic_shape):
return 1, 2
attrs = [
program_config.ops[i].attrs for i in range(len(program_config.ops))
]
# for static_shape
clear_dynamic_shape()
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, False), 1e-2
# for dynamic_shape
generate_dynamic_shape(attrs)
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), generate_trt_nodes_num(
attrs, True), 1e-2
def test(self):
self.run_test()
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册