diff --git a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
index f765d9c22bbd5bf906c3f34f8486fc8e61f05fef..d3a0b4ede44e437bf7156ab1cec21deb272d38d0 100644
--- a/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
+++ b/paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
@@ -28,6 +28,7 @@
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/inference/analysis/helper.h"
 #include "paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h"
+#include "paddle/fluid/inference/api/helper.h"
 #include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
 #include "paddle/fluid/inference/tensorrt/engine.h"
 #include "paddle/fluid/inference/tensorrt/op_teller.h"
@@ -74,7 +75,7 @@ void OutputProcess(framework::ir::Graph *graph,
     for (auto *var_node : op_node->outputs) {
       if (!trt_outputs.count(var_node)) continue;
       if (!var_node->Var()->Persistable() &&
-          tensorrt::IsFloatVar(var_node->Var()->GetDataType()) &&
+          IsFloatVar(var_node->Var()->GetDataType()) &&
           var_node->Var()->GetDataType() != framework::proto::VarType::FP32) {
         for (auto *next_op : var_node->outputs) {
           // if next_op support mixed_precision, we need to add cast op.
diff --git a/paddle/fluid/inference/api/CMakeLists.txt b/paddle/fluid/inference/api/CMakeLists.txt
index 697daea542089d5b2d4cacba161f250708908993..6c6c18a88cb6fa1531960eef1b166e63bf4ee7d7 100755
--- a/paddle/fluid/inference/api/CMakeLists.txt
+++ b/paddle/fluid/inference/api/CMakeLists.txt
@@ -102,6 +102,11 @@ cc_test(
   SRCS api_tester.cc
   DEPS paddle_inference_api)
 
+cc_test(
+  inference_api_helper_test
+  SRCS helper_test.cc
+  DEPS paddle_inference_api)
+
 if(WITH_ONNXRUNTIME AND WIN32)
   # Copy onnxruntime for some c++ test in Windows, since the test will
   # be build only in CI, so suppose the generator in Windows is Ninja.
diff --git a/paddle/fluid/inference/api/helper.h b/paddle/fluid/inference/api/helper.h
index e3b145381280cd017498f9169173239118a3d134..de92281bb07a747860652e2167c176036871244f 100644
--- a/paddle/fluid/inference/api/helper.h
+++ b/paddle/fluid/inference/api/helper.h
@@ -81,6 +81,15 @@ inline PaddleDType ConvertToPaddleDType(
   }
 }
 
+inline bool IsFloatVar(framework::proto::VarType::Type t) {
+  if (t == framework::proto::VarType::FP16 ||
+      t == framework::proto::VarType::FP32 ||
+      t == framework::proto::VarType::FP64 ||
+      t == framework::proto::VarType::BF16)
+    return true;
+  return false;
+}
+
 using paddle::framework::DataTypeToString;
 
 // Timer for timer
diff --git a/paddle/fluid/inference/api/helper_test.cc b/paddle/fluid/inference/api/helper_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..aae20f461128d1631d3dffa3c1d6fc30a8eca614
--- /dev/null
+++ b/paddle/fluid/inference/api/helper_test.cc
@@ -0,0 +1,35 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/inference/api/helper.h"
+
+#include "gtest/gtest.h"
+
+namespace paddle {
+
+TEST(inference_api_helper, DataType) {
+  ASSERT_TRUE(
+      paddle::inference::IsFloatVar(paddle::framework::proto::VarType::FP64));
+  ASSERT_TRUE(
+      paddle::inference::IsFloatVar(paddle::framework::proto::VarType::FP32));
+  ASSERT_TRUE(
+      paddle::inference::IsFloatVar(paddle::framework::proto::VarType::FP16));
+  ASSERT_TRUE(
+      paddle::inference::IsFloatVar(paddle::framework::proto::VarType::BF16));
+
+  ASSERT_FALSE(
+      paddle::inference::IsFloatVar(paddle::framework::proto::VarType::INT32));
+}
+
+}  // namespace paddle
diff --git a/paddle/fluid/inference/tensorrt/helper.h b/paddle/fluid/inference/tensorrt/helper.h
index 0b435c4c1214b2b79043c4af5aac65f13611249c..9d48024c61190e273fef608c651bf142314d34ec 100644
--- a/paddle/fluid/inference/tensorrt/helper.h
+++ b/paddle/fluid/inference/tensorrt/helper.h
@@ -22,7 +22,6 @@
 #include <string>
 #include <vector>
 
-#include "paddle/fluid/framework/framework.pb.h"
 #include "paddle/fluid/platform/dynload/tensorrt.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/phi/common/data_type.h"
@@ -215,14 +214,6 @@ static inline nvinfer1::DataType PhiType2NvType(phi::DataType type) {
   return nv_type;
 }
 
-static bool IsFloatVar(framework::proto::VarType::Type t) {
-  if (t == framework::proto::VarType::FP16 ||
-      t == framework::proto::VarType::FP32 ||
-      t == framework::proto::VarType::FP64 ||
-      t == framework::proto::VarType::BF16)
-    return true;
-  return false;
-}
 }  // namespace tensorrt
 }  // namespace inference
 }  // namespace paddle
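
Usage note (not part of the patch): the sketch below, assuming a normal build of this source tree, shows how code outside the TensorRT converter can reuse the relocated helper. NeedsCastToFP32 and var_type are hypothetical names introduced only for illustration; paddle::inference::IsFloatVar and the FP32 comparison come from the hunks above.

// Minimal usage sketch, not part of the patch. After this change IsFloatVar
// lives in paddle::inference (paddle/fluid/inference/api/helper.h) rather
// than paddle::inference::tensorrt. NeedsCastToFP32 is a hypothetical wrapper
// mirroring the check in tensorrt_subgraph_pass.cc above.
#include "paddle/fluid/inference/api/helper.h"

namespace paddle {
namespace inference {

inline bool NeedsCastToFP32(framework::proto::VarType::Type var_type) {
  // True for FP16/FP64/BF16 variables that need a cast when leaving a
  // mixed-precision TensorRT subgraph; FP32 variables need no cast.
  return IsFloatVar(var_type) &&
         var_type != framework::proto::VarType::FP32;
}

}  // namespace inference
}  // namespace paddle

One consequence visible in the CMake hunk: because the helper now sits in api/helper.h, the new inference_api_helper_test can exercise it with only the paddle_inference_api dependency, without pulling in TensorRT headers.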